onnxruntime/onnxruntime/test/optimizer/optimizer_test.cc
Tang, Cheng a81faee41e
Multi-stream execution support (#13495)
**Description**: This PR includes the following work:
1. Provide stream and related synchronization abstractions in
onnxruntime.
2. Enhance onnxruntime's execution planner / executor / memory arena to
support executing multiple streams in parallel.
3. Deprecate the parallel executor for CPU.
4. Deprecate the Fence mechanism.
5. Update the CUDA / TensorRT EP to support the stream mechanism and
support running different requests in different CUDA streams.

**Motivation and Context**
- Why is this change required?
Currently, the execution plan is just a linear list of those primitives,
and ORT executes them step by step. For any given graph, ORT
serializes it to a fixed execution order. This sequential execution
design simplifies most scenarios, but it has the following limitations:
1. It is difficult to enable inter-node parallelization. We have a
half-baked parallel executor, but it is very difficult to make it work
with GPU.
2. The fence mechanism works for the single GPU stream + CPU thread
case, but when extended to multiple streams, it is difficult to manage
cross-GPU-stream synchronization.
3. Our CUDA EP relies on the BFCArena to make memory management work
with the GPU async kernels, but the current BFCArena is not aware of
streams, so it doesn't behave correctly when run with multiple
streams.

This PR enhances our existing execution plan and executor to support
multiple-stream execution. We use a unified algorithm to manage both
single-stream and multiple-stream scenarios.
This PR mainly focuses on the infrastructure support for multiple-stream
execution; that is, given a valid stream assignment, onnxruntime
can execute it correctly. How to generate a good stream assignment for a
given model will be addressed in a future PR.
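
To make the idea of executing "a valid stream assignment" concrete, here is a minimal single-threaded sketch. It is not onnxruntime's planner or executor: `Step`, `StreamPlan`, and `RunPlan` are hypothetical names invented for this illustration, and real streams would run asynchronously rather than in a round-robin loop. Each logical stream is an ordered list of steps, and a step may wait on nodes that run in other streams.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical types for illustration only; these are not onnxruntime classes.
struct Step {
  int node;                // id of the node this step runs
  std::vector<int> waits;  // ids of nodes (possibly on other streams) that must finish first
};

using StreamPlan = std::vector<Step>;  // the steps of one logical stream, in order

// Simulates a stream assignment on one thread: visit the streams round-robin and
// advance a stream only when every node its next step waits on has finished.
// Returns the order in which nodes completed, or an empty vector if no stream
// can make progress (the assignment is not valid).
std::vector<int> RunPlan(const std::vector<StreamPlan>& streams, int num_nodes) {
  std::vector<bool> done(static_cast<std::size_t>(num_nodes), false);
  std::vector<std::size_t> next(streams.size(), 0);
  std::size_t total = 0;
  for (const auto& stream : streams) total += stream.size();

  std::vector<int> order;
  bool progressed = true;
  while (order.size() < total && progressed) {
    progressed = false;
    for (std::size_t s = 0; s < streams.size(); ++s) {
      if (next[s] == streams[s].size()) continue;  // this stream is finished
      const Step& step = streams[s][next[s]];
      assert(step.node >= 0 && step.node < num_nodes);
      bool ready = true;
      for (int dep : step.waits) ready = ready && done[static_cast<std::size_t>(dep)];
      if (!ready) continue;  // blocked on a node in another stream
      done[static_cast<std::size_t>(step.node)] = true;
      order.push_back(step.node);
      ++next[s];
      progressed = true;
    }
  }
  if (order.size() != total) order.clear();  // stuck: cyclic or unsatisfiable waits
  return order;
}
```

With a single stream and no waits this degenerates to the linear step-by-step plan described above, which is the sense in which one algorithm can cover both cases in this sketch.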

Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Cheng Tang <chenta@microsoft.com>
Co-authored-by: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: cao lei <jslhcl@gmail.com>
Co-authored-by: Lei Cao <leca@microsoft.com>
2022-12-15 07:39:29 -08:00


// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/framework/data_types.h"
#include "core/framework/ort_value.h"
#include "core/framework/op_kernel.h"
#include "core/util/math.h"
#include "core/platform/env.h"
#include "test/framework/test_utils.h"
#include "test/capturing_sink.h"
#include "test/test_environment.h"
#include "asserts.h"
#include "gtest/gtest.h"
using namespace std;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
TEST(OptimizerTest, Basic) {
  Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
  auto& graph = model.MainGraph();
  constexpr int tensor_dim = 10;
  constexpr int input_num = 2;
  TensorProto initializer_tensor[input_num];
  std::vector<std::unique_ptr<NodeArg>> inputs(input_num);
  std::vector<std::unique_ptr<NodeArg>> outputs(1);
  InitializedTensorSet initialized_tensor_set;
  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_dim);
  for (int i = 0; i < input_num; i++) {
    string name("input_" + std::to_string(i));
    inputs[i] = std::make_unique<NodeArg>(name, &tensor_int32);
    initializer_tensor[i].set_name(inputs[i]->Name());
    initializer_tensor[i].add_dims(tensor_dim);
    initializer_tensor[i].set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
    for (int j = 0; j < tensor_dim; j++) {
      initializer_tensor[i].add_int32_data((i + 1) * j);
    }
    initialized_tensor_set[name] = &initializer_tensor[i];
  }
  outputs[0] = std::make_unique<NodeArg>("out", &tensor_int32);
  std::vector<NodeArg*> tmp_inputs{inputs[0].get(), inputs[1].get()};
  std::vector<NodeArg*> tmp_outputs{outputs[0].get()};
  graph.AddNode("a", "Add", "a", tmp_inputs, tmp_outputs);
  ASSERT_STATUS_OK(graph.Resolve());
  std::vector<const Node*> nodes;
  for (auto& node : graph.Nodes()) {
    nodes.push_back(&node);
  }
  std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
#if !defined(DISABLE_SPARSE_TENSORS)
  OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
                                     graph.ModelPath(),
                                     *cpu_execution_provider.get(),
                                     [&graph](const std::string& name) -> bool {
                                       return graph.IsSparseInitializer(name);
                                     });
#else
  OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
                                     graph.ModelPath(),
                                     *cpu_execution_provider.get(),
                                     [](std::string const&) { return false; });
#endif  // !defined(DISABLE_SPARSE_TENSORS)
  std::vector<int> fetch_mlvalue_idxs{info.GetMLValueIndex("out")};
  OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
  const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
  for (auto& node : graph.Nodes()) {
    auto kernel = info.CreateKernel(&node);
    // kernel can only be a nullptr if a CPU kernel implementation has been removed,
    // if that is the case, OpKernelContext instance construction will throw in the next step
    // and fail the test
#ifdef _WIN32
#pragma warning(push)
#pragma warning(disable : 6387)
#endif
    OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger);
#ifdef _WIN32
#pragma warning(pop)
#endif
    auto st = kernel->Compute(&op_kernel_context);
    ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
    std::vector<OrtValue> fetches;
    ASSERT_STATUS_OK(frame.GetOutputs(fetches));
    auto& tensor = fetches[0].Get<Tensor>();
    const std::vector<int32_t> found(tensor.Data<int32_t>(), tensor.Data<int32_t>() + tensor_dim);
    std::vector<int32_t> expected;
    for (int j = 0; j < tensor_dim; j++) {
      expected.push_back(3 * j);
    }
    ASSERT_EQ(expected, found);
  }
}
} // namespace test
} // namespace onnxruntime
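
The expected values in the test follow from the initializer data: element `j` of `input_0` is `1 * j` and element `j` of `input_1` is `2 * j`, so the `Add` node should produce `3 * j`. The standalone sketch below reproduces that arithmetic without any onnxruntime dependency. `MakeInitializer` and `AddElementwise` are hypothetical helpers written for this illustration, with a plain loop standing in for the CPU `Add` kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the test's initializer data: element j of input i is (i + 1) * j.
// (Hypothetical helper, not part of onnxruntime.)
std::vector<int32_t> MakeInitializer(int i, int tensor_dim) {
  std::vector<int32_t> data;
  data.reserve(static_cast<std::size_t>(tensor_dim));
  for (int j = 0; j < tensor_dim; ++j) {
    data.push_back((i + 1) * j);
  }
  return data;
}

// Plain element-wise addition of two equally sized vectors, standing in for
// the CPU Add kernel. (Hypothetical helper, not part of onnxruntime.)
std::vector<int32_t> AddElementwise(const std::vector<int32_t>& a,
                                    const std::vector<int32_t>& b) {
  assert(a.size() == b.size());
  std::vector<int32_t> out(a.size());
  for (std::size_t k = 0; k < a.size(); ++k) {
    out[k] = a[k] + b[k];
  }
  return out;
}
```

Calling `AddElementwise(MakeInitializer(0, 10), MakeInitializer(1, 10))` yields the same ten values the test builds in `expected`.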