pytorch/caffe2/core/observer_test.cc
Yangqing Jia 65112f3865 code cleanup: separate the several net implementations to separate files.
Summary: TSIA.

Reviewed By: harouwu

Differential Revision: D5670906

fbshipit-source-id: 507e789978144341bf696fb20dc11f3c2d55493b
2017-08-21 22:07:48 -07:00

144 lines
3.6 KiB
C++

#include <atomic>
#include <climits>

#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>

#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/net_dag.h"
#include "caffe2/core/net_simple.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/registry.h"
#include "caffe2/core/scope_guard.h"
namespace caffe2 {
namespace {
// Shared tally bumped by the DummyObserver hooks below. Tests read it before
// and after running a net and compare the difference, so its absolute value
// never matters. Internal linkage already comes from the enclosing anonymous
// namespace.
std::atomic<int> counter{0};
// Test observer that can be attached either to a NetBase or to an
// OperatorBase. Start()/Stop() are only declared here; each subject type gets
// an explicit specialization further down that bumps the shared `counter`.
template <class T>
class DummyObserver final : public ObserverBase<T> {
 public:
  // The constructor uses the injected-class-name rather than
  // `DummyObserver<T>`, because a simple-template-id as a constructor
  // declarator is ill-formed since C++20 (CWG 2237). The parameter is named
  // `subject` so it does not shadow the inherited `subject_` member.
  explicit DummyObserver(T* subject) : ObserverBase<T>(subject) {}

  bool Start() override;
  bool Stop() override;
  // No user-declared destructor (Rule of Zero): the implicit one is
  // equivalent to the old empty body and inherits virtual-ness from the base.
};
// Net-level start hook. Before the net runs, it hands every operator of the
// observed net its own operator-level DummyObserver, so the operators' hooks
// fire during this run too. It then records the net start by adding 1000.
template <>
bool DummyObserver<NetBase>::Start() {
  for (OperatorBase* op : subject_->GetOperators()) {
    op->SetObserver(caffe2::make_unique<DummyObserver<OperatorBase>>(op));
  }
  counter.fetch_add(1000);
  return true;
}
// Operator-level start hook: adds 100 to the shared tally.
template <>
bool DummyObserver<OperatorBase>::Start() {
  counter += 100;
  return true;
}

// Operator-level stop hook: adds 1 to the shared tally.
template <>
bool DummyObserver<OperatorBase>::Stop() {
  counter += 1;
  return true;
}

// Net-level stop hook: adds 10 to the shared tally.
template <>
bool DummyObserver<NetBase>::Stop() {
  counter += 10;
  return true;
}
// Do-nothing operator for these tests. Running it only invokes the attached
// observer's Start() and then Stop() (when an observer is present), and it
// always reports success. Its inputs and outputs are never touched.
class ObsTestDummyOp final : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;

  bool Run(int /* unused */) override {
    // The observer is re-checked before Stop() rather than tested once up
    // front -- presumably in case Start() detaches it; keep both checks.
    if (observer_ != nullptr) {
      observer_->Start();
    }
    if (observer_ != nullptr) {
      observer_->Stop();
    }
    return true;
  }
};
// Register ObsTestDummyOp under the type name "ObsTestDummy" in both the CPU
// and CUDA operator registries, so NetDefs built by CreateNetTestHelper can
// instantiate it by name.
REGISTER_CPU_OPERATOR(ObsTestDummy, ObsTestDummyOp);
REGISTER_CUDA_OPERATOR(ObsTestDummy, ObsTestDummyOp);
// Schema: any number of inputs and outputs is accepted, since the op never
// reads or writes its blobs. AllowInplace lists (input, output) index pairs
// that are permitted to refer to the same blob.
OPERATOR_SCHEMA(ObsTestDummy)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
// Builds and instantiates, in `ws`, a two-operator chain of ObsTestDummy ops:
//   "in" -> [op 0] -> "hidden" -> [op 1] -> "out"
// When `isDAG` is true the net type is set to "dag"; otherwise the NetDef's
// type is left unset. Returns whatever CreateNet produces for that definition.
unique_ptr<NetBase> CreateNetTestHelper(Workspace* ws, bool isDAG = false) {
  NetDef net_def;
  if (isDAG) {
    net_def.set_type("dag");
  }

  // Appends one ObsTestDummy op reading `input` and writing `output`.
  const auto append_dummy_op = [&net_def](const char* input,
                                          const char* output) {
    auto* op_def = net_def.add_op();
    op_def->set_type("ObsTestDummy");
    op_def->add_input(input);
    op_def->add_output(output);
  };
  append_dummy_op("in", "hidden");
  append_dummy_op("hidden", "out");

  net_def.add_external_input("in");
  net_def.add_external_output("out");
  return CreateNet(net_def, ws);
}
}
// A SimpleNet with a DummyObserver attached must fire the net hooks and, via
// the observers the net hook installs on each op, both op hooks for each of
// the two ops.
TEST(ObserverTest, TestNotify) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
  // Fail cleanly instead of dereferencing null if net creation failed.
  ASSERT_TRUE(net != nullptr);
  EXPECT_EQ(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get()), net.get());
  auto net_ob = caffe2::make_unique<DummyObserver<NetBase>>(net.get());
  net->SetObserver(std::move(net_ob));
  EXPECT_TRUE(net->Run());
  const auto count_after = counter.load();
  // Net Start (1000) + net Stop (10) + 2 ops x (op Start 100 + op Stop 1).
  EXPECT_EQ(1212, count_after - count_before);
}
// Attaching and then removing the net observer before Run() must leave the
// shared counter untouched: no net or op hooks may fire.
TEST(ObserverTest, TestNotifyAfterDetach) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
  // Fail cleanly instead of dereferencing null if net creation failed.
  ASSERT_TRUE(net != nullptr);
  auto net_ob = caffe2::make_unique<DummyObserver<NetBase>>(net.get());
  net->SetObserver(std::move(net_ob));
  net->RemoveObserver();
  EXPECT_TRUE(net->Run());
  const auto count_after = counter.load();
  EXPECT_EQ(0, count_after - count_before);
}
// Same expectation as TestNotify, but for the "dag" net type: the observer
// hooks must fire identically when the net is a DAGNetBase.
TEST(ObserverTest, TestDAGNetBase) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws, true));
  // Fail cleanly instead of dereferencing null if net creation failed.
  ASSERT_TRUE(net != nullptr);
  EXPECT_EQ(caffe2::dynamic_cast_if_rtti<DAGNetBase*>(net.get()), net.get());
  auto net_ob = caffe2::make_unique<DummyObserver<NetBase>>(net.get());
  net->SetObserver(std::move(net_ob));
  EXPECT_TRUE(net->Run());
  const auto count_after = counter.load();
  // Net Start (1000) + net Stop (10) + 2 ops x (op Start 100 + op Stop 1).
  EXPECT_EQ(1212, count_after - count_before);
}
}