PR comments (#3374)

* PR comments

* PR comments

* PR comments

* PR comments

* PR comments

* PR comments

* PR comments

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-04-01 10:36:16 -07:00 committed by GitHub
parent 614eb438ae
commit 2ce90cff4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 46 deletions

View file

@ -796,6 +796,7 @@ if (onnxruntime_USE_DML)
endif()
if (onnxruntime_ENABLE_TRAINING)
add_compile_definitions(ENABLE_TRAINING)
if (onnxruntime_USE_HOROVOD)
if (WIN32)
message( FATAL_ERROR "Horovod is not supported on Windows." )

View file

@ -30,6 +30,3 @@ if (WIN32)
set_target_properties(onnxruntime_framework PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/ConfigureVisualStudioCodeAnalysis.props)
endif()
if(onnxruntime_ENABLE_TRAINING)
target_compile_definitions(onnxruntime_framework PUBLIC ENABLE_TRAINING)
endif()

View file

@ -187,12 +187,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
node.Name());
std::string node_name = node.Name();
if (node_name.empty()) {
// Node name field is often blank in execution graph, so derive something meaningful for profile traces and logs.
node_name = node.OpType() + "_" + std::to_string(node_index);
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
LARGE_INTEGER kernel_start;
QueryPerformanceCounter(&kernel_start);
@ -241,22 +235,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node());
#endif
const std::string node_name_for_profiling = [&]() -> std::string {
if (!is_profiler_enabled) return {};
// Derive something meaningful for profile traces and logs if node name field is blank in execution graph
return node.Name().empty() ? MakeString(node.OpType(), "_", node_index) : node.Name();
}();
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_before",
node_name_for_profiling + "_fence_before",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
// call compute on the kernel
VLOGS(logger, 1) << "Computing kernel: " << node_name;
VLOGS(logger, 1) << "Computing kernel: " << node_name_for_profiling;
kernel_begin_time = session_state.Profiler().StartTime();
}
if (is_profiler_enabled) {
// Calculate total input sizes for this operation.
CalculateTotalInputSizes(&op_kernel_context, p_op_kernel,
input_activation_sizes,input_parameter_sizes, node_name);
input_activation_sizes,input_parameter_sizes, node_name_for_profiling);
}
#ifdef CONCURRENCY_VISUALIZER
@ -286,12 +284,12 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
if (is_profiler_enabled) {
// Calculate total output sizes for this operation.
CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name);
CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name_for_profiling);
#if defined(TRACE_EXECUTION)
// Trace execution step.
const Node& node = p_op_kernel->Node();
std::cout << "Executed op kernel node " << node_name
std::cout << "Executed op kernel node " << node_name_for_profiling
<< " Index=" << node.Index()
<< " OpType=" << node.OpType()
<< " Name=" << node.Name()
@ -302,7 +300,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
#endif
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_kernel_time",
node_name_for_profiling + "_kernel_time",
kernel_begin_time,
// Log additional operation args / info.
{
@ -356,7 +354,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
#endif
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_after",
node_name_for_profiling + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}
@ -366,7 +364,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
#endif
// free ml-values corresponding to this node
VLOGS(logger, 1) << "Releasing node ML values after computing kernel: " << node_name;
VLOGS(logger, 1) << "Releasing node ML values.";
ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
}

View file

@ -136,6 +136,7 @@ const std::unordered_map<int, OrtValue>& SessionState::GetConstantInitializedTen
return constant_initialized_tensors_;
}
#ifdef ENABLE_TRAINING
Status SessionState::GetInitializedTensors(
const std::unordered_set<std::string>& interested_weights,
bool allow_missing_weights, NameMLValMap& retrieved_weights) const {
@ -162,6 +163,7 @@ NameMLValMap SessionState::GetInitializedTensors(const std::unordered_set<std::s
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return result;
}
#endif
SessionState& SessionState::SetLogger(const logging::Logger& logger) {
logger_ = &logger;
@ -186,6 +188,7 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
return key;
}
#ifdef ENABLE_TRAINING
namespace {
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
for (const auto* input : graph.GetInputs()) {
@ -280,6 +283,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
}
return Status::OK();
}
#endif
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
const std::vector<int>& feed_mlvalue_idxs) const {
@ -459,7 +463,9 @@ const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
}
void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue_idxs) {
if (to_be_executed_nodes_.find(fetch_mlvalue_idxs) != to_be_executed_nodes_.end())
std::vector<int> sorted_idxs = fetch_mlvalue_idxs;
std::sort(sorted_idxs.begin(), sorted_idxs.end());
if (to_be_executed_nodes_.find(sorted_idxs) != to_be_executed_nodes_.end())
return;
const Graph& graph = GetGraphViewer()->GetGraph();
@ -471,10 +477,8 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
for (auto idx : fetch_mlvalue_idxs) {
std::string node_arg_name;
if (!this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name).IsOK()) {
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
return;
}
const auto status = this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
auto ending_node = graph.GetProducerNode(node_arg_name);
nodes.push_back(ending_node);
}
@ -482,12 +486,14 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
// Reversely traverse to get reachable nodes.
graph.ReverseDFSFrom(
nodes, {}, [&reachable_nodes](const Node* n) { reachable_nodes.insert(n->Index()); });
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
to_be_executed_nodes_.insert(std::make_pair(sorted_idxs, reachable_nodes));
}
const std::unordered_set<NodeIndex>* SessionState::GetToBeExecutedNodes(
const std::vector<int>& fetch_mlvalue_idxs) const {
auto it = to_be_executed_nodes_.find(fetch_mlvalue_idxs);
std::vector<int> sorted_idxs = fetch_mlvalue_idxs;
std::sort(sorted_idxs.begin(), sorted_idxs.end());
auto it = to_be_executed_nodes_.find(sorted_idxs);
return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr;
}

View file

@ -116,6 +116,7 @@ class SessionState {
*/
const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;
#ifdef ENABLE_TRAINING
/**
Get some initialized tensors (weights).
@param interested_weights The names of the weights to retrieve.
@ -133,6 +134,7 @@ class SessionState {
Any names in interested_weights with no corresponding weight are ignored.
*/
NameMLValMap GetInitializedTensors(const std::unordered_set<std::string>& interested_weights) const;
#endif
// execution plan
void SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan);
@ -249,10 +251,12 @@ class SessionState {
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
#ifdef ENABLE_TRAINING
Status GeneratePatternGroupCache(
const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
const std::vector<int>& feed_mlvalue_idxs,
MemoryPatternGroup* output) const;
#endif
// cache of the constructed kernels to avoid spending construction
// time per executor

View file

@ -10,16 +10,21 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
static std::pair<bool, Node*> IsInputTranspose(const Graph& graph, NodeArg& node_arg) {
const auto& trans_node = graph.GetProducerNode(node_arg.Name());
static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
Node* trans_node = graph.GetMutableProducerNode(node_arg.Name());
if (trans_node == nullptr || trans_node->OpType() != "Transpose") {
return std::make_pair(false, nullptr);
return nullptr;
}
// if the node has Graph output, skip it too
if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) {
return nullptr;
}
auto perms = RetrieveValues<int64_t>(trans_node->GetAttributes().at("perm"));
int64_t rank = perms.size();
if (rank < 2) {
return std::make_pair(false, nullptr);
return nullptr;
}
bool is_trans_on_last_two_dims = true;
@ -31,14 +36,14 @@ static std::pair<bool, Node*> IsInputTranspose(const Graph& graph, NodeArg& node
}
if (is_trans_on_last_two_dims) {
is_trans_on_last_two_dims = (int64_t)perms[rank - 2] == rank - 1 && (int64_t)perms[rank - 1] == rank - 2;
is_trans_on_last_two_dims = perms[rank - 2] == rank - 1 && perms[rank - 1] == rank - 2;
}
if (!is_trans_on_last_two_dims) {
return std::make_pair(false, nullptr);
return nullptr;
}
return std::make_pair(true, const_cast<Node*>(trans_node));
return trans_node;
}
static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_map<NodeArg*, size_t>& count_map) {
@ -72,27 +77,27 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
}
NodeArg* left_input = node.MutableInputDefs()[0];
auto left = IsInputTranspose(graph, *left_input);
auto left = GetTransposeNodeFromOutput(graph, *left_input);
NodeArg* right_input = node.MutableInputDefs()[1];
auto right = IsInputTranspose(graph, *right_input);
auto right = GetTransposeNodeFromOutput(graph, *right_input);
if (!left.first && !right.first) {
if (!left && !right) {
continue;
}
if (left.first) {
if (left) {
size_t left_consumers = UpdateConsumerCount(graph, left_input, consumer_count);
if (left_consumers == 0)
removed_nodes.push_front(left.second->Index());
left_input = left.second->MutableInputDefs()[0];
removed_nodes.push_front(left->Index());
left_input = left->MutableInputDefs()[0];
}
if (right.first) {
if (right) {
size_t right_consumers = UpdateConsumerCount(graph, right_input, consumer_count);
if (right_consumers == 0)
removed_nodes.push_front(right.second->Index());
right_input = right.second->MutableInputDefs()[0];
removed_nodes.push_front(right->Index());
right_input = right->MutableInputDefs()[0];
}
const std::vector<NodeArg*> input_defs{left_input, right_input};
@ -103,16 +108,16 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
"fused MatMul and Transpose ",
input_defs,
output_defs, {}, kMSDomain);
bool transpose_left = left.first;
bool transpose_left = (left != nullptr);
if (node.OpType() == "TransposeMatMul") {
transpose_left ^= static_cast<bool>(node.GetAttributes().at("transA").i());
}
bool transpose_right = right.first;
bool transpose_right = (right != nullptr);
if (node.OpType() == "TransposeMatMul") {
transpose_right ^= static_cast<bool>(node.GetAttributes().at("transB").i());
}
matmul_node.AddAttribute("transA", (int64_t)transpose_left);
matmul_node.AddAttribute("transB", (int64_t)transpose_right);
matmul_node.AddAttribute("transA", static_cast<int64_t>(transpose_left));
matmul_node.AddAttribute("transB", static_cast<int64_t>(transpose_right));
// Assign provider to this new node. Provider should be same as the provider for old node.
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());