pytorch/caffe2/core/observer_test.cc
Sebastian Messmer 643ca5def2 Replace c10::guts::stuff with std::stuff (#30915)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30915

Since we now have C++14, we don't need these c10::guts helpers anymore
ghstack-source-id: 95777609

Test Plan: waitforsandcastle

Differential Revision: D18869639

fbshipit-source-id: 97716f932297c64c6e814410ac47b444c33d4e2e
2019-12-16 13:57:19 -08:00

182 lines
4.9 KiB
C++

#include <gtest/gtest.h>
#include "c10/util/Registry.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/net_simple.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/scope_guard.h"
namespace caffe2 {
namespace {
// Shared tally bumped by every DummyObserver callback in this file. Each
// callback adds a distinct power of ten so a test can tell from the delta
// which callbacks fired and how often.
static std::atomic<int> counter{0};
/// Minimal observer used by the tests below.
///
/// Start() and Stop() are only declared here; their behaviour is supplied by
/// the explicit specializations for NetBase and OperatorBase that follow.
///
/// @tparam T The observed subject type (NetBase or OperatorBase in this file).
template <class T>
class DummyObserver final : public ObserverBase<T> {
 public:
  /// @param subject_ The object being observed; forwarded to ObserverBase<T>.
  // The constructor is named with the injected-class-name rather than
  // `DummyObserver<T>`: a simple-template-id is not a valid constructor
  // declarator-id as of C++20 (CWG 2237) and newer GCC versions diagnose it.
  explicit DummyObserver(T* subject_) : ObserverBase<T>(subject_) {}
  ~DummyObserver() override = default;

  void Start() override;
  void Stop() override;
};
/// Net-level start hook: attaches a fresh operator-level DummyObserver to each
/// operator reported by the observed net, then adds 1000 to the counter.
template <>
void DummyObserver<NetBase>::Start() {
  for (OperatorBase* net_op : subject_->GetOperators()) {
    net_op->AttachObserver(
        std::make_unique<DummyObserver<OperatorBase>>(net_op));
  }
  counter += 1000;
}
/// Operator-level start hook: adds 100 to the counter.
template <>
void DummyObserver<OperatorBase>::Start() {
  counter += 100;
}
/// Net-level stop hook: adds 10 to the counter.
template <>
void DummyObserver<NetBase>::Stop() {
  counter += 10;
}
/// Operator-level stop hook: adds 1 to the counter.
template <>
void DummyObserver<OperatorBase>::Stop() {
  counter += 1;
}
/// Operator that performs no computation: running it only fires the start and
/// stop notifications on its attached observers and reports success.
class ObsTestDummyOp final : public OperatorBase {
 public:
  // Reuse the OperatorBase constructors unchanged.
  using OperatorBase::OperatorBase;

  /// Notifies observers of start, then of stop. Always returns true.
  bool Run(int /* unused */) override {
    StartAllObservers();
    StopAllObservers();
    return true;
  }
};
// Register ObsTestDummyOp under the operator name ObsTestDummy (the type
// string CreateNetTestHelper puts in its NetDef) for both the CPU and CUDA
// registries.
REGISTER_CPU_OPERATOR(ObsTestDummy, ObsTestDummyOp);
REGISTER_CUDA_OPERATOR(ObsTestDummy, ObsTestDummyOp);
// Schema for ObsTestDummy: between 0 and INT_MAX inputs and outputs, with the
// (input, output) index pairs (0, 0) and (1, 1) passed to AllowInplace.
// NOTE(review): presumably this lets those outputs alias the matching inputs;
// confirm against the OpSchema documentation.
OPERATOR_SCHEMA(ObsTestDummy)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
/// Builds a two-operator test net: "in" -> ObsTestDummy -> "hidden" ->
/// ObsTestDummy -> "out".
///
/// @param ws    Workspace handed to CreateNet.
/// @param isDAG When true, the net type is set to "dag"; otherwise the type
///              field is left unset.
/// @return The net produced by CreateNet for the assembled definition.
unique_ptr<NetBase> CreateNetTestHelper(Workspace* ws, bool isDAG = false) {
  NetDef net_def;
  if (isDAG) {
    net_def.set_type("dag");
  }

  // Appends one ObsTestDummy operator with a single input and a single output.
  const auto append_dummy_op = [&net_def](const char* input_name,
                                          const char* output_name) {
    auto* new_op = net_def.add_op();
    new_op->set_type("ObsTestDummy");
    new_op->add_input(input_name);
    new_op->add_output(output_name);
  };
  append_dummy_op("in", "hidden");
  append_dummy_op("hidden", "out");

  net_def.add_external_input("in");
  net_def.add_external_output("out");
  return CreateNet(net_def, ws);
}
}
// Attaches a net-level DummyObserver to the default net type, runs the net
// once and checks how far the shared counter moved.
TEST(ObserverTest, TestNotify) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
  // Without a type set in the NetDef, the created net is expected to be a
  // SimpleNet.
  EXPECT_EQ(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get()), net.get());
  net->AttachObserver(std::make_unique<DummyObserver<NetBase>>(net.get()));
  net->Run();
  const auto count_after = counter.load();
  // Expected delta: 1000 (net Start) + 2 ops * (100 Start + 1 Stop)
  // + 10 (net Stop) = 1212.
  EXPECT_EQ(1212, count_after - count_before);
}
// Same scenario as a plain notify run, but the observer is detached after the
// run using the raw pointer returned by AttachObserver, and ownership of the
// detached observer is taken back by the test.
TEST(ObserverTest, TestUniqueMap) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
  EXPECT_EQ(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get()), net.get());
  auto* ref =
      net->AttachObserver(std::make_unique<DummyObserver<NetBase>>(net.get()));
  net->Run();
  // Hold the detached observer until the end of the test body.
  unique_ptr<Observable<NetBase>::Observer> test = net->DetachObserver(ref);
  const auto count_after = counter.load();
  // Expected delta: 1000 (net Start) + 2 ops * (100 Start + 1 Stop)
  // + 10 (net Stop) = 1212.
  EXPECT_EQ(1212, count_after - count_before);
}
// An observer that is detached before the net runs must not receive any
// callbacks, so the counter is expected to stay unchanged.
TEST(ObserverTest, TestNotifyAfterDetach) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
  auto* ob =
      net->AttachObserver(std::make_unique<DummyObserver<NetBase>>(net.get()));
  net->DetachObserver(ob);
  net->Run();
  const auto count_after = counter.load();
  EXPECT_EQ(0, count_after - count_before);
}
// Repeats the notify scenario with the net type set to "dag" and expects the
// same counter delta as for the default net type.
TEST(ObserverTest, TestDAGNetBase) {
  const auto count_before = counter.load();
  Workspace ws;
  ws.CreateBlob("in");
  unique_ptr<NetBase> net(CreateNetTestHelper(&ws, true));
  net->AttachObserver(std::make_unique<DummyObserver<NetBase>>(net.get()));
  net->Run();
  const auto count_after = counter.load();
  // Expected delta: 1000 (net Start) + 2 ops * (100 Start + 1 Stop)
  // + 10 (net Stop) = 1212.
  EXPECT_EQ(1212, count_after - count_before);
}
// Compiled out: the test below is kept for reference only and is not built.
#if 0
// This test intermittently segfaults,
// see https://github.com/pytorch/pytorch/issues/9137
//
// Intent: attach 100 net-level DummyObservers to a "dag" net, run it once,
// detach every observer again and check that the observer count returns to
// the value it had before any were attached.
TEST(ObserverTest, TestMultipleNetBase) {
Workspace ws;
ws.CreateBlob("in");
NetDef net_def;
unique_ptr<NetBase> net(CreateNetTestHelper(&ws, true));
EXPECT_EQ(caffe2::dynamic_cast_if_rtti<NetBase*>(net.get()), net.get());
// There may be some default observers
const size_t prev_num = net.get()->NumObservers();
const int num_tests = 100;
// Raw pointers returned by AttachObserver, kept so each observer can be
// detached after the run.
vector<const Observable<NetBase>::Observer*> observers;
for (int i = 0; i < num_tests; ++i) {
unique_ptr<DummyObserver<NetBase>> net_ob =
make_unique<DummyObserver<NetBase>>(net.get());
observers.emplace_back(net.get()->AttachObserver(std::move(net_ob)));
}
net.get()->Run();
for (const auto& observer : observers) {
net.get()->DetachObserver(observer);
}
// After detaching everything added above, only the pre-existing observers
// should remain.
EXPECT_EQ(net.get()->NumObservers(), prev_num);
}
#endif
} // namespace caffe2