mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
keep extra_info of each op in ProfDagStats (#15244)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15244 This diff keeps track of the extra_info information attached to each operator. When getPerOpStats() is called, it attaches the extra_info to the resulting ProfDagStats protobuf. The Facebook net transform attaches a global_op_id, defined as a tuple of (orig_net_name, original_op_index), to each operator. The global_op_id is encoded as extra_info in each operator. Reviewed By: aazzolini Differential Revision: D13016289 fbshipit-source-id: 3e2719ec7ed0ebe47740b77581c565ff7e79b102
This commit is contained in:
parent
692898fe37
commit
cd3c4a2f1c
3 changed files with 30 additions and 5 deletions
|
|
@ -1,4 +1,5 @@
|
|||
#include "caffe2/core/prof_dag_counters.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
|
@ -10,8 +11,20 @@ ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def) {
|
|||
report_.num_runs_ = 0;
|
||||
auto num_ops = net_def->op_size();
|
||||
report_.op_types_.reserve(num_ops);
|
||||
report_.op_extra_info_.reserve(num_ops);
|
||||
|
||||
for (auto op_id = 0; op_id < num_ops; ++op_id) {
|
||||
report_.op_types_.push_back(net_def->op(op_id).type());
|
||||
vector<std::string> op_extra_info;
|
||||
if (net_def->op(op_id).has_device_option() &&
|
||||
net_def->op(op_id).device_option().extra_info_size() > 0) {
|
||||
for (auto i = 0; i < net_def->op(op_id).device_option().extra_info_size();
|
||||
++i) {
|
||||
auto extra_info_str = net_def->op(op_id).device_option().extra_info(i);
|
||||
op_extra_info.push_back(extra_info_str);
|
||||
}
|
||||
}
|
||||
report_.op_extra_info_.push_back(op_extra_info);
|
||||
}
|
||||
report_.time_per_op_total_.resize(num_ops);
|
||||
}
|
||||
|
|
@ -99,12 +112,16 @@ ProfDAGReport ProfDAGCounters::GetReport() const {
|
|||
|
||||
ProfDAGProto ProfDAGReport::statsProto(
|
||||
const std::string& name,
|
||||
const ProfDAGStats& stats) const {
|
||||
const ProfDAGStats& stats,
|
||||
const std::vector<std::string>& op_extra_info) const {
|
||||
ProfDAGProto stats_proto;
|
||||
const auto& moments = stats.computeMoments();
|
||||
stats_proto.set_mean(moments.first);
|
||||
stats_proto.set_stddev(moments.second);
|
||||
stats_proto.set_name(name);
|
||||
for (auto& extra_info : op_extra_info) {
|
||||
stats_proto.add_extra_info(extra_info);
|
||||
}
|
||||
return stats_proto;
|
||||
}
|
||||
|
||||
|
|
@ -114,7 +131,7 @@ ProfDAGProtos ProfDAGReport::GetOperatorStats() const {
|
|||
if (num_runs_ > 1) {
|
||||
for (auto& item : time_per_op_type_total_) {
|
||||
auto buf = prof_dag_protos.add_stats();
|
||||
buf->CopyFrom(statsProto(item.first, item.second));
|
||||
buf->CopyFrom(statsProto(item.first, item.second, vector<std::string>()));
|
||||
}
|
||||
}
|
||||
return prof_dag_protos;
|
||||
|
|
@ -129,7 +146,8 @@ ProfDAGProtos ProfDAGReport::GetPerOperatorCost() const {
|
|||
auto buf = prof_dag_protos.add_stats();
|
||||
std::string op_output_name =
|
||||
net_name_ + "___" + to_string(op_id) + "___" + op_type;
|
||||
buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
|
||||
buf->CopyFrom(statsProto(
|
||||
op_output_name, time_per_op_total_[op_id], op_extra_info_[op_id]));
|
||||
}
|
||||
}
|
||||
return prof_dag_protos;
|
||||
|
|
|
|||
|
|
@ -64,10 +64,13 @@ class ProfDAGReport {
|
|||
void PrintStats();
|
||||
|
||||
private:
|
||||
ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats)
|
||||
const;
|
||||
ProfDAGProto statsProto(
|
||||
const std::string& name,
|
||||
const ProfDAGStats& stats,
|
||||
const std::vector<std::string>& op_extra_info) const;
|
||||
|
||||
std::vector<std::string> op_types_;
|
||||
std::vector<std::vector<std::string>> op_extra_info_;
|
||||
|
||||
std::string net_name_;
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ message ProfDAGProto {
|
|||
|
||||
// Blob profiles that this node outputs.
|
||||
repeated BlobProfile output_profile = 5;
|
||||
|
||||
// The extra_info from the operator device option.
|
||||
repeated string extra_info = 7;
|
||||
|
||||
}
|
||||
|
||||
// Operator profiling information.
|
||||
|
|
|
|||
Loading…
Reference in a new issue