diff --git a/caffe2/core/prof_dag_counters.cc b/caffe2/core/prof_dag_counters.cc index 6b576d3d722..0187a854e59 100644 --- a/caffe2/core/prof_dag_counters.cc +++ b/caffe2/core/prof_dag_counters.cc @@ -1,4 +1,5 @@ #include "caffe2/core/prof_dag_counters.h" +#include "caffe2/utils/string_utils.h" #include #include @@ -10,8 +11,20 @@ ProfDAGCounters::ProfDAGCounters(const std::shared_ptr& net_def) { report_.num_runs_ = 0; auto num_ops = net_def->op_size(); report_.op_types_.reserve(num_ops); + report_.op_extra_info_.reserve(num_ops); + for (auto op_id = 0; op_id < num_ops; ++op_id) { report_.op_types_.push_back(net_def->op(op_id).type()); + vector op_extra_info; + if (net_def->op(op_id).has_device_option() && + net_def->op(op_id).device_option().extra_info_size() > 0) { + for (auto i = 0; i < net_def->op(op_id).device_option().extra_info_size(); + ++i) { + auto extra_info_str = net_def->op(op_id).device_option().extra_info(i); + op_extra_info.push_back(extra_info_str); + } + } + report_.op_extra_info_.push_back(op_extra_info); } report_.time_per_op_total_.resize(num_ops); } @@ -99,12 +112,16 @@ ProfDAGReport ProfDAGCounters::GetReport() const { ProfDAGProto ProfDAGReport::statsProto( const std::string& name, - const ProfDAGStats& stats) const { + const ProfDAGStats& stats, + const std::vector& op_extra_info) const { ProfDAGProto stats_proto; const auto& moments = stats.computeMoments(); stats_proto.set_mean(moments.first); stats_proto.set_stddev(moments.second); stats_proto.set_name(name); + for (auto& extra_info : op_extra_info) { + stats_proto.add_extra_info(extra_info); + } return stats_proto; } @@ -114,7 +131,7 @@ ProfDAGProtos ProfDAGReport::GetOperatorStats() const { if (num_runs_ > 1) { for (auto& item : time_per_op_type_total_) { auto buf = prof_dag_protos.add_stats(); - buf->CopyFrom(statsProto(item.first, item.second)); + buf->CopyFrom(statsProto(item.first, item.second, vector())); } } return prof_dag_protos; @@ -129,7 +146,8 @@ ProfDAGProtos ProfDAGReport::GetPerOperatorCost() const { auto buf = prof_dag_protos.add_stats(); std::string op_output_name = net_name_ + "___" + to_string(op_id) + "___" + op_type; - buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id])); + buf->CopyFrom(statsProto( + op_output_name, time_per_op_total_[op_id], op_extra_info_[op_id])); } } return prof_dag_protos; diff --git a/caffe2/core/prof_dag_counters.h b/caffe2/core/prof_dag_counters.h index 18c8dc82eb1..a1a494ee02e 100644 --- a/caffe2/core/prof_dag_counters.h +++ b/caffe2/core/prof_dag_counters.h @@ -64,10 +64,13 @@ class ProfDAGReport { void PrintStats(); private: - ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats) - const; + ProfDAGProto statsProto( + const std::string& name, + const ProfDAGStats& stats, + const std::vector& op_extra_info) const; std::vector op_types_; + std::vector> op_extra_info_; std::string net_name_; diff --git a/caffe2/proto/prof_dag.proto b/caffe2/proto/prof_dag.proto index 343cff1f66a..c6820d41cfd 100644 --- a/caffe2/proto/prof_dag.proto +++ b/caffe2/proto/prof_dag.proto @@ -42,6 +42,10 @@ message ProfDAGProto { // Blob profiles that this node outputs. repeated BlobProfile output_profile = 5; + + // The extra_info from the operator device option. + repeated string extra_info = 7; + } // Operator profiling information.