mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
TimeObserver for SimpleNet, an example usage of Observers.
Summary: Implemented TimeObserver for SimpleNet. Reviewed By: pietern Differential Revision: D5188373 fbshipit-source-id: 530d75d176aa29d38c131338c3a2be70bc221a47
This commit is contained in:
parent
d3ec6e8f55
commit
a2521148b4
6 changed files with 178 additions and 1 deletions
|
|
@ -45,6 +45,7 @@ option(USE_ROCKSDB "Use RocksDB" ON)
|
|||
option(USE_REDIS "Use Redis" OFF)
|
||||
option(USE_MPI "Use MPI" ON)
|
||||
option(USE_GLOO "Use Gloo" ON)
|
||||
option(USE_OBSERVERS "Use Observer Library" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build libcaffe2.so" ON)
|
||||
option(USE_OPENMP "Use OpenMP for parallel code" ON)
|
||||
option(BUILD_PYTHON "Build Python binaries" ON)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
add_subdirectory(gloo)
|
||||
add_subdirectory(nccl)
|
||||
add_subdirectory(nnpack)
|
||||
|
||||
add_subdirectory(observers)
|
||||
# Finally pass the src lists back to the parent
|
||||
|
||||
# CPU source, test sources, binary sources
|
||||
|
|
|
|||
9
caffe2/contrib/observers/CMakeLists.txt
Normal file
9
caffe2/contrib/observers/CMakeLists.txt
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
if(USE_OBSERVERS)
|
||||
message(STATUS "Include Observer library")
|
||||
set(Caffe2_CONTRIB_OBSERVERS_CPU_SRC
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/time_observer.cc"
|
||||
)
|
||||
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${Caffe2_CONTRIB_OBSERVERS_CPU_SRC})
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
|
||||
endif()
|
||||
41
caffe2/contrib/observers/time_observer.cc
Normal file
41
caffe2/contrib/observers/time_observer.cc
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#include "caffe2/contrib/observers/time_observer.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <>
|
||||
bool TimeObserver<SimpleNet>::Start() {
|
||||
vector<OperatorBase*> operators = subject.getOperators();
|
||||
for (auto& op : operators) {
|
||||
children_.push_back(caffe2::make_unique<TimeObserver<OperatorBase>>(*op));
|
||||
}
|
||||
start_time_ = timer_.MilliSeconds();
|
||||
++iterations_;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TimeObserver<SimpleNet>::Stop() {
|
||||
double current_run = timer_.MilliSeconds() - start_time_;
|
||||
total_time_ += current_run;
|
||||
VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TimeObserver<OperatorBase>::Start() {
|
||||
start_time_ = timer_.MilliSeconds();
|
||||
++iterations_;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TimeObserver<OperatorBase>::Stop() {
|
||||
double current_run = timer_.MilliSeconds() - start_time_;
|
||||
total_time_ += current_run;
|
||||
VLOG(1) << "This operator iteration took " << current_run
|
||||
<< " ms to complete.\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
43
caffe2/contrib/observers/time_observer.h
Normal file
43
caffe2/contrib/observers/time_observer.h
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#ifndef CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
|
||||
#define CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/observer.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/timer.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class T>
|
||||
class TimeObserver final : public ObserverBase<T> {
|
||||
public:
|
||||
explicit TimeObserver<T>(T& subject) : ObserverBase<T>(subject) {}
|
||||
inline float average_time() const {
|
||||
return total_time_ / iterations_;
|
||||
}
|
||||
float average_time_children() const {
|
||||
float sum = 0.0f;
|
||||
for (auto& ob : children_) {
|
||||
sum += ob.get()->average_time();
|
||||
}
|
||||
return sum / children_.size();
|
||||
}
|
||||
~TimeObserver() {}
|
||||
|
||||
private:
|
||||
Timer timer_;
|
||||
float start_time_ = 0.0f;
|
||||
float total_time_ = 0.0f;
|
||||
int iterations_ = 0;
|
||||
|
||||
vector<unique_ptr<TimeObserver<OperatorBase>>> children_;
|
||||
|
||||
bool Start() override;
|
||||
bool Stop() override;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
|
||||
83
caffe2/contrib/observers/time_observer_test.cc
Normal file
83
caffe2/contrib/observers/time_observer_test.cc
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
#include "caffe2/contrib/observers/time_observer.h"
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/net.h"
|
||||
#include "caffe2/core/observer.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
||||
class SleepOp final : public OperatorBase {
|
||||
public:
|
||||
using OperatorBase::OperatorBase;
|
||||
bool Run(int /* unused */) override {
|
||||
if (observer_) {
|
||||
observer_->Start();
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(3000));
|
||||
if (observer_) {
|
||||
observer_->Stop();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CPU_OPERATOR(SleepOp, SleepOp);
|
||||
REGISTER_CUDA_OPERATOR(SleepOp, SleepOp);
|
||||
|
||||
OPERATOR_SCHEMA(SleepOp)
|
||||
.NumInputs(0, INT_MAX)
|
||||
.NumOutputs(0, INT_MAX)
|
||||
.AllowInplace({{0, 0}, {1, 1}});
|
||||
|
||||
const std::basic_string<char> kExampleNetDefString = {
|
||||
" name: \"example\""
|
||||
" op {"
|
||||
" input: \"in\""
|
||||
" output: \"hidden\""
|
||||
" type: \"SleepOp\""
|
||||
" }"
|
||||
" op {"
|
||||
" input: \"hidden\""
|
||||
" output: \"out\""
|
||||
" type: \"SleepOp\""
|
||||
" }"};
|
||||
|
||||
unique_ptr<NetBase> CreateNetTestHelper(
|
||||
Workspace* ws,
|
||||
const vector<string>& input,
|
||||
const vector<string>& output) {
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(google::protobuf::TextFormat::ParseFromString(
|
||||
kExampleNetDefString, &net_def));
|
||||
for (const auto& name : input) {
|
||||
net_def.add_external_input(name);
|
||||
}
|
||||
for (const auto& name : output) {
|
||||
net_def.add_external_output(name);
|
||||
}
|
||||
return CreateNet(net_def, ws);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TimeObserverTest, Test3Seconds) {
|
||||
Workspace ws;
|
||||
ws.CreateBlob("in");
|
||||
NetDef net_def;
|
||||
unique_ptr<NetBase> net(CreateNetTestHelper(&ws, {"in"}, {"out"}));
|
||||
unique_ptr<TimeObserver<SimpleNet>> net_ob =
|
||||
make_unique<TimeObserver<SimpleNet>>(
|
||||
*(caffe2::dynamic_cast_if_rtti<SimpleNet*>(net.get())));
|
||||
net.get()->Run();
|
||||
CAFFE_ENFORCE(net_ob.get()->average_time_children() > 3000);
|
||||
CAFFE_ENFORCE(net_ob.get()->average_time_children() < 3500);
|
||||
CAFFE_ENFORCE(net_ob.get()->average_time() > 6000);
|
||||
CAFFE_ENFORCE(net_ob.get()->average_time() < 6500);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue