keep extra_info of each op in ProfDagStats (#15244)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15244

This DIFF keeps track of the extra_info information attached to each operator. When getPerOpStas() is called, it attaches the extra_info to the result ProfDagStats protobuf.

Facebook
Net transform attaches a global_op_id which is defined as a tuple of (orig_net_name, original_op_index) to each operator,
The global_op_id is encoded as extra_info in each operator.

Reviewed By: aazzolini

Differential Revision: D13016289

fbshipit-source-id: 3e2719ec7ed0ebe47740b77581c565ff7e79b102
This commit is contained in:
Dong Li 2018-12-28 15:00:41 -08:00 committed by Facebook Github Bot
parent 692898fe37
commit cd3c4a2f1c
3 changed files with 30 additions and 5 deletions

View file

@ -1,4 +1,5 @@
#include "caffe2/core/prof_dag_counters.h"
#include "caffe2/utils/string_utils.h"
#include <ostream>
#include <sstream>
@ -10,8 +11,20 @@ ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def) {
report_.num_runs_ = 0;
auto num_ops = net_def->op_size();
report_.op_types_.reserve(num_ops);
report_.op_extra_info_.reserve(num_ops);
for (auto op_id = 0; op_id < num_ops; ++op_id) {
report_.op_types_.push_back(net_def->op(op_id).type());
vector<std::string> op_extra_info;
if (net_def->op(op_id).has_device_option() &&
net_def->op(op_id).device_option().extra_info_size() > 0) {
for (auto i = 0; i < net_def->op(op_id).device_option().extra_info_size();
++i) {
auto extra_info_str = net_def->op(op_id).device_option().extra_info(i);
op_extra_info.push_back(extra_info_str);
}
}
report_.op_extra_info_.push_back(op_extra_info);
}
report_.time_per_op_total_.resize(num_ops);
}
@ -99,12 +112,16 @@ ProfDAGReport ProfDAGCounters::GetReport() const {
ProfDAGProto ProfDAGReport::statsProto(
const std::string& name,
const ProfDAGStats& stats) const {
const ProfDAGStats& stats,
const std::vector<std::string>& op_extra_info) const {
ProfDAGProto stats_proto;
const auto& moments = stats.computeMoments();
stats_proto.set_mean(moments.first);
stats_proto.set_stddev(moments.second);
stats_proto.set_name(name);
for (auto& extra_info : op_extra_info) {
stats_proto.add_extra_info(extra_info);
}
return stats_proto;
}
@ -114,7 +131,7 @@ ProfDAGProtos ProfDAGReport::GetOperatorStats() const {
if (num_runs_ > 1) {
for (auto& item : time_per_op_type_total_) {
auto buf = prof_dag_protos.add_stats();
buf->CopyFrom(statsProto(item.first, item.second));
buf->CopyFrom(statsProto(item.first, item.second, vector<std::string>()));
}
}
return prof_dag_protos;
@ -129,7 +146,8 @@ ProfDAGProtos ProfDAGReport::GetPerOperatorCost() const {
auto buf = prof_dag_protos.add_stats();
std::string op_output_name =
net_name_ + "___" + to_string(op_id) + "___" + op_type;
buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
buf->CopyFrom(statsProto(
op_output_name, time_per_op_total_[op_id], op_extra_info_[op_id]));
}
}
return prof_dag_protos;

View file

@ -64,10 +64,13 @@ class ProfDAGReport {
void PrintStats();
private:
ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats)
const;
ProfDAGProto statsProto(
const std::string& name,
const ProfDAGStats& stats,
const std::vector<std::string>& op_extra_info) const;
std::vector<std::string> op_types_;
std::vector<std::vector<std::string>> op_extra_info_;
std::string net_name_;

View file

@ -42,6 +42,10 @@ message ProfDAGProto {
// Blob profiles that this node outputs.
repeated BlobProfile output_profile = 5;
// The extra_info from the operator device option.
repeated string extra_info = 7;
}
// Operator profiling information.