mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
PR comments (#3374)
* PR comments * PR comments * PR comments * PR comments * PR comments * PR comments * PR comments Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
614eb438ae
commit
2ce90cff4c
6 changed files with 57 additions and 46 deletions
|
|
@ -796,6 +796,7 @@ if (onnxruntime_USE_DML)
|
|||
endif()
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
add_compile_definitions(ENABLE_TRAINING)
|
||||
if (onnxruntime_USE_HOROVOD)
|
||||
if (WIN32)
|
||||
message( FATAL_ERROR "Horovod is not supported on Windows." )
|
||||
|
|
|
|||
|
|
@ -30,6 +30,3 @@ if (WIN32)
|
|||
set_target_properties(onnxruntime_framework PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/ConfigureVisualStudioCodeAnalysis.props)
|
||||
endif()
|
||||
|
||||
if(onnxruntime_ENABLE_TRAINING)
|
||||
target_compile_definitions(onnxruntime_framework PUBLIC ENABLE_TRAINING)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -187,12 +187,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
|
||||
node.Name());
|
||||
|
||||
std::string node_name = node.Name();
|
||||
if (node_name.empty()) {
|
||||
// Node name field is often blank in execution graph, so derive something meaningful for profile traces and logs.
|
||||
node_name = node.OpType() + "_" + std::to_string(node_index);
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
LARGE_INTEGER kernel_start;
|
||||
QueryPerformanceCounter(&kernel_start);
|
||||
|
|
@ -241,22 +235,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node());
|
||||
#endif
|
||||
|
||||
const std::string node_name_for_profiling = [&]() -> std::string {
|
||||
if (!is_profiler_enabled) return {};
|
||||
// Derive something meaningful for profile traces and logs if node name field is blank in execution graph
|
||||
return node.Name().empty() ? MakeString(node.OpType(), "_", node_index) : node.Name();
|
||||
}();
|
||||
|
||||
if (is_profiler_enabled) {
|
||||
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
|
||||
node_name + "_fence_before",
|
||||
node_name_for_profiling + "_fence_before",
|
||||
sync_time_begin,
|
||||
{{"op_name", p_op_kernel->KernelDef().OpName()}});
|
||||
|
||||
// call compute on the kernel
|
||||
VLOGS(logger, 1) << "Computing kernel: " << node_name;
|
||||
VLOGS(logger, 1) << "Computing kernel: " << node_name_for_profiling;
|
||||
|
||||
kernel_begin_time = session_state.Profiler().StartTime();
|
||||
}
|
||||
|
||||
if (is_profiler_enabled) {
|
||||
// Calculate total input sizes for this operation.
|
||||
CalculateTotalInputSizes(&op_kernel_context, p_op_kernel,
|
||||
input_activation_sizes,input_parameter_sizes, node_name);
|
||||
input_activation_sizes,input_parameter_sizes, node_name_for_profiling);
|
||||
}
|
||||
|
||||
#ifdef CONCURRENCY_VISUALIZER
|
||||
|
|
@ -286,12 +284,12 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
|
||||
if (is_profiler_enabled) {
|
||||
// Calculate total output sizes for this operation.
|
||||
CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name);
|
||||
CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name_for_profiling);
|
||||
|
||||
#if defined(TRACE_EXECUTION)
|
||||
// Trace execution step.
|
||||
const Node& node = p_op_kernel->Node();
|
||||
std::cout << "Executed op kernel node " << node_name
|
||||
std::cout << "Executed op kernel node " << node_name_for_profiling
|
||||
<< " Index=" << node.Index()
|
||||
<< " OpType=" << node.OpType()
|
||||
<< " Name=" << node.Name()
|
||||
|
|
@ -302,7 +300,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
#endif
|
||||
|
||||
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
|
||||
node_name + "_kernel_time",
|
||||
node_name_for_profiling + "_kernel_time",
|
||||
kernel_begin_time,
|
||||
// Log additional operation args / info.
|
||||
{
|
||||
|
|
@ -356,7 +354,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
#endif
|
||||
if (is_profiler_enabled) {
|
||||
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
|
||||
node_name + "_fence_after",
|
||||
node_name_for_profiling + "_fence_after",
|
||||
sync_time_begin,
|
||||
{{"op_name", p_op_kernel->KernelDef().OpName()}});
|
||||
}
|
||||
|
|
@ -366,7 +364,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
#endif
|
||||
|
||||
// free ml-values corresponding to this node
|
||||
VLOGS(logger, 1) << "Releasing node ML values after computing kernel: " << node_name;
|
||||
VLOGS(logger, 1) << "Releasing node ML values.";
|
||||
ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -136,6 +136,7 @@ const std::unordered_map<int, OrtValue>& SessionState::GetConstantInitializedTen
|
|||
return constant_initialized_tensors_;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
Status SessionState::GetInitializedTensors(
|
||||
const std::unordered_set<std::string>& interested_weights,
|
||||
bool allow_missing_weights, NameMLValMap& retrieved_weights) const {
|
||||
|
|
@ -162,6 +163,7 @@ NameMLValMap SessionState::GetInitializedTensors(const std::unordered_set<std::s
|
|||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
SessionState& SessionState::SetLogger(const logging::Logger& logger) {
|
||||
logger_ = &logger;
|
||||
|
|
@ -186,6 +188,7 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
|
|||
return key;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
namespace {
|
||||
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
|
||||
for (const auto* input : graph.GetInputs()) {
|
||||
|
|
@ -280,6 +283,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
|
||||
const std::vector<int>& feed_mlvalue_idxs) const {
|
||||
|
|
@ -459,7 +463,9 @@ const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
|
|||
}
|
||||
|
||||
void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue_idxs) {
|
||||
if (to_be_executed_nodes_.find(fetch_mlvalue_idxs) != to_be_executed_nodes_.end())
|
||||
std::vector<int> sorted_idxs = fetch_mlvalue_idxs;
|
||||
std::sort(sorted_idxs.begin(), sorted_idxs.end());
|
||||
if (to_be_executed_nodes_.find(sorted_idxs) != to_be_executed_nodes_.end())
|
||||
return;
|
||||
|
||||
const Graph& graph = GetGraphViewer()->GetGraph();
|
||||
|
|
@ -471,10 +477,8 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
|
|||
|
||||
for (auto idx : fetch_mlvalue_idxs) {
|
||||
std::string node_arg_name;
|
||||
if (!this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name).IsOK()) {
|
||||
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
|
||||
return;
|
||||
}
|
||||
const auto status = this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name);
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
auto ending_node = graph.GetProducerNode(node_arg_name);
|
||||
nodes.push_back(ending_node);
|
||||
}
|
||||
|
|
@ -482,12 +486,14 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
|
|||
// Reversely traverse to get reachable nodes.
|
||||
graph.ReverseDFSFrom(
|
||||
nodes, {}, [&reachable_nodes](const Node* n) { reachable_nodes.insert(n->Index()); });
|
||||
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
|
||||
to_be_executed_nodes_.insert(std::make_pair(sorted_idxs, reachable_nodes));
|
||||
}
|
||||
|
||||
const std::unordered_set<NodeIndex>* SessionState::GetToBeExecutedNodes(
|
||||
const std::vector<int>& fetch_mlvalue_idxs) const {
|
||||
auto it = to_be_executed_nodes_.find(fetch_mlvalue_idxs);
|
||||
std::vector<int> sorted_idxs = fetch_mlvalue_idxs;
|
||||
std::sort(sorted_idxs.begin(), sorted_idxs.end());
|
||||
auto it = to_be_executed_nodes_.find(sorted_idxs);
|
||||
return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ class SessionState {
|
|||
*/
|
||||
const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
/**
|
||||
Get some initialized tensors (weights).
|
||||
@param interested_weights The names of the weights to retrieve.
|
||||
|
|
@ -133,6 +134,7 @@ class SessionState {
|
|||
Any names in interested_weights with no corresponding weight are ignored.
|
||||
*/
|
||||
NameMLValMap GetInitializedTensors(const std::unordered_set<std::string>& interested_weights) const;
|
||||
#endif
|
||||
|
||||
// execution plan
|
||||
void SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan);
|
||||
|
|
@ -249,10 +251,12 @@ class SessionState {
|
|||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
Status GeneratePatternGroupCache(
|
||||
const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
|
||||
const std::vector<int>& feed_mlvalue_idxs,
|
||||
MemoryPatternGroup* output) const;
|
||||
#endif
|
||||
|
||||
// cache of the constructed kernels to avoid spending construction
|
||||
// time per executor
|
||||
|
|
|
|||
|
|
@ -10,16 +10,21 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
static std::pair<bool, Node*> IsInputTranspose(const Graph& graph, NodeArg& node_arg) {
|
||||
const auto& trans_node = graph.GetProducerNode(node_arg.Name());
|
||||
static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
|
||||
Node* trans_node = graph.GetMutableProducerNode(node_arg.Name());
|
||||
if (trans_node == nullptr || trans_node->OpType() != "Transpose") {
|
||||
return std::make_pair(false, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// if the node has Graph output, skip it too
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto perms = RetrieveValues<int64_t>(trans_node->GetAttributes().at("perm"));
|
||||
int64_t rank = perms.size();
|
||||
if (rank < 2) {
|
||||
return std::make_pair(false, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool is_trans_on_last_two_dims = true;
|
||||
|
|
@ -31,14 +36,14 @@ static std::pair<bool, Node*> IsInputTranspose(const Graph& graph, NodeArg& node
|
|||
}
|
||||
|
||||
if (is_trans_on_last_two_dims) {
|
||||
is_trans_on_last_two_dims = (int64_t)perms[rank - 2] == rank - 1 && (int64_t)perms[rank - 1] == rank - 2;
|
||||
is_trans_on_last_two_dims = perms[rank - 2] == rank - 1 && perms[rank - 1] == rank - 2;
|
||||
}
|
||||
|
||||
if (!is_trans_on_last_two_dims) {
|
||||
return std::make_pair(false, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::make_pair(true, const_cast<Node*>(trans_node));
|
||||
return trans_node;
|
||||
}
|
||||
|
||||
static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_map<NodeArg*, size_t>& count_map) {
|
||||
|
|
@ -72,27 +77,27 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
}
|
||||
|
||||
NodeArg* left_input = node.MutableInputDefs()[0];
|
||||
auto left = IsInputTranspose(graph, *left_input);
|
||||
auto left = GetTransposeNodeFromOutput(graph, *left_input);
|
||||
|
||||
NodeArg* right_input = node.MutableInputDefs()[1];
|
||||
auto right = IsInputTranspose(graph, *right_input);
|
||||
auto right = GetTransposeNodeFromOutput(graph, *right_input);
|
||||
|
||||
if (!left.first && !right.first) {
|
||||
if (!left && !right) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left.first) {
|
||||
if (left) {
|
||||
size_t left_consumers = UpdateConsumerCount(graph, left_input, consumer_count);
|
||||
if (left_consumers == 0)
|
||||
removed_nodes.push_front(left.second->Index());
|
||||
left_input = left.second->MutableInputDefs()[0];
|
||||
removed_nodes.push_front(left->Index());
|
||||
left_input = left->MutableInputDefs()[0];
|
||||
}
|
||||
|
||||
if (right.first) {
|
||||
if (right) {
|
||||
size_t right_consumers = UpdateConsumerCount(graph, right_input, consumer_count);
|
||||
if (right_consumers == 0)
|
||||
removed_nodes.push_front(right.second->Index());
|
||||
right_input = right.second->MutableInputDefs()[0];
|
||||
removed_nodes.push_front(right->Index());
|
||||
right_input = right->MutableInputDefs()[0];
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> input_defs{left_input, right_input};
|
||||
|
|
@ -103,16 +108,16 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
"fused MatMul and Transpose ",
|
||||
input_defs,
|
||||
output_defs, {}, kMSDomain);
|
||||
bool transpose_left = left.first;
|
||||
bool transpose_left = (left != nullptr);
|
||||
if (node.OpType() == "TransposeMatMul") {
|
||||
transpose_left ^= static_cast<bool>(node.GetAttributes().at("transA").i());
|
||||
}
|
||||
bool transpose_right = right.first;
|
||||
bool transpose_right = (right != nullptr);
|
||||
if (node.OpType() == "TransposeMatMul") {
|
||||
transpose_right ^= static_cast<bool>(node.GetAttributes().at("transB").i());
|
||||
}
|
||||
matmul_node.AddAttribute("transA", (int64_t)transpose_left);
|
||||
matmul_node.AddAttribute("transB", (int64_t)transpose_right);
|
||||
matmul_node.AddAttribute("transA", static_cast<int64_t>(transpose_left));
|
||||
matmul_node.AddAttribute("transB", static_cast<int64_t>(transpose_right));
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue