mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Enable prepacking in subgraph (#5433)
Prepacking in subgraph is not supported currently. We see more and more models with subgraph, which has MatMul, MatMulInteger and other ops. Prepacking can speed up those models significantly.
This commit is contained in:
parent
564da960ce
commit
30cdc74bc0
4 changed files with 218 additions and 61 deletions
|
|
@ -121,7 +121,7 @@ class Node {
|
|||
|
||||
/** Gets the Node's Node::Type. */
|
||||
Node::Type NodeType() const noexcept { return node_type_; }
|
||||
|
||||
|
||||
/** Gets the opset version that the Node's operator was first defined in.
|
||||
@returns Opset version. If -1 the Node's operator has not been set.
|
||||
@remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
|
||||
|
|
@ -1029,13 +1029,12 @@ class Graph {
|
|||
|
||||
/** Returns true if the name is for a value that is coming from outer scope */
|
||||
bool IsOuterScopeValue(const std::string& name) const {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
return resolve_context_.outer_scope_node_args.find(name) != resolve_context_.outer_scope_node_args.cend();
|
||||
#else
|
||||
// we shouldn't have code that calls this in a minimal build
|
||||
ORT_UNUSED_PARAMETER(name);
|
||||
ORT_THROW("Internal error. Outer scope value lookup is not currently supported in a minimal build.");
|
||||
#endif
|
||||
if (!parent_node_) return false;
|
||||
const auto& implicit_input_defs = parent_node_->ImplicitInputDefs();
|
||||
return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
|
||||
[&name](const NodeArg* implicit_input) {
|
||||
return implicit_input->Name() == name;
|
||||
});
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -234,36 +234,39 @@ void SessionState::CleanInitializedTensorsFromGraph() {
|
|||
graph_.CleanAllInitializedTensors();
|
||||
}
|
||||
|
||||
Status SessionState::PrepackInitializedConstantTensors() {
|
||||
// calculate the use count of each value
|
||||
std::unordered_map<std::string, size_t> node_arg_use_count;
|
||||
for (const auto& node : GetGraphViewer().Nodes()) {
|
||||
node.ForEachDef([&](const onnxruntime::NodeArg& node_arg, bool is_input) {
|
||||
if (is_input) {
|
||||
node_arg_use_count[node_arg.Name()]++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status SessionState::PrepackConstantInitializedTensors(std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
|
||||
for (auto& node : GetGraphViewer().Nodes()) {
|
||||
auto kernel = GetMutableKernel(node.Index());
|
||||
int input_idx = 0;
|
||||
for (auto& input_def : node.InputDefs()) {
|
||||
if (input_def->Exists()) {
|
||||
const std::string& input_name = input_def->Name();
|
||||
int ort_value_idx;
|
||||
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(input_name, ort_value_idx));
|
||||
if (constant_initialized_tensors_.count(ort_value_idx) &&
|
||||
constant_initialized_tensors_[ort_value_idx].IsTensor()) {
|
||||
bool is_packed = false;
|
||||
const Tensor& const_initialized_tensor = constant_initialized_tensors_[ort_value_idx].Get<Tensor>();
|
||||
ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed));
|
||||
if (is_packed && node_arg_use_count.count(input_name) && --node_arg_use_count[input_name] == 0) {
|
||||
// release the constant intialized tensor
|
||||
initialized_tensors_.erase(ort_value_idx);
|
||||
constant_initialized_tensors_.erase(ort_value_idx);
|
||||
SessionState* st = this;
|
||||
// subgraph can use the value from outer scope,
|
||||
// so it needs to check if current node uses constant initialized tensor from current and outer graphs
|
||||
do {
|
||||
int ort_value_idx;
|
||||
if (st->GetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) {
|
||||
std::unordered_map<int, OrtValue>& constant_initialized_tensors = st->constant_initialized_tensors_;
|
||||
if (constant_initialized_tensors.count(ort_value_idx)) {
|
||||
bool is_packed = false;
|
||||
const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get<Tensor>();
|
||||
ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed));
|
||||
if (is_packed && constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) {
|
||||
// release the constant initialized tensor
|
||||
st->initialized_tensors_.erase(ort_value_idx);
|
||||
constant_initialized_tensors.erase(ort_value_idx);
|
||||
}
|
||||
}
|
||||
// stop searching in 2 cases:
|
||||
// 1. value is not from OuterScope
|
||||
// 2. value is from OuterScope and the current OuterScope has the value
|
||||
if (st != this || !st->graph_.IsOuterScopeValue(input_name)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
st = st->Parent();
|
||||
} while (st);
|
||||
}
|
||||
input_idx++;
|
||||
}
|
||||
|
|
@ -567,10 +570,13 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s
|
|||
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
|
||||
" for attribute ", attribute_name);
|
||||
}
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
|
||||
session_state->parent_ = this;
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
GenerateGraphId();
|
||||
#endif
|
||||
|
||||
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
|
||||
}
|
||||
|
||||
|
|
@ -776,6 +782,27 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
|
|||
}
|
||||
#endif
|
||||
|
||||
// Calculate the use count of a constant initialized tensor, including the use in subgraph.
|
||||
// Note: This function doesn't handle the case below:
|
||||
// The main graph has a constant initializer called X, and the subgraph also has a constant initializer called X, which overrides the X from main graph.
|
||||
// For case like this, the current implementation will calculate the use count as 2, but they could contain completely different values so each should have a use count of 1.
|
||||
// This is a very rare case. If it happens and X is prepacked, the consequence is that X won't be released and memory usage of X won't be saved. This will be fine.
|
||||
static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
for (const auto* arg : node.InputDefs()) {
|
||||
if (arg->Exists() && graph.GetConstantInitializer(arg->Name(), true /*check_outer_scope*/)) {
|
||||
constant_initializers_use_count[arg->Name()]++;
|
||||
}
|
||||
}
|
||||
|
||||
if (node.ContainsSubgraph()) {
|
||||
for (const gsl::not_null<const Graph*>& subgraph : node.GetSubgraphs()) {
|
||||
ComputeConstantInitializerUseCount(*subgraph, constant_initializers_use_count);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const SessionOptions& session_options,
|
||||
|
|
@ -807,15 +834,18 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
|
|||
#endif
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, size_t> constant_initializers_use_count;
|
||||
ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count);
|
||||
return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options,
|
||||
remove_initializers);
|
||||
remove_initializers, constant_initializers_use_count);
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers) {
|
||||
bool remove_initializers,
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
|
||||
CreateGraphInfo();
|
||||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
|
|
@ -868,7 +898,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
session_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisablePrepacking, "0");
|
||||
|
||||
if (disable_prepacking != "1") {
|
||||
ORT_RETURN_IF_ERROR(PrepackInitializedConstantTensors());
|
||||
ORT_RETURN_IF_ERROR(PrepackConstantInitializedTensors(constant_initializers_use_count));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(
|
||||
|
|
@ -896,7 +926,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl(
|
||||
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers));
|
||||
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, constant_initializers_use_count));
|
||||
|
||||
// setup all the info for handling the feeds and fetches used in subgraph execution
|
||||
auto* p_op_kernel = GetMutableKernel(node.Index());
|
||||
|
|
|
|||
|
|
@ -279,6 +279,10 @@ class SessionState {
|
|||
const onnxruntime::experimental::fbs::SessionState* serialized_session_state = nullptr,
|
||||
bool remove_initializers = true);
|
||||
|
||||
SessionState* Parent() {
|
||||
return parent_;
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
|
||||
|
||||
|
|
@ -298,7 +302,7 @@ class SessionState {
|
|||
* Prepack the constant initialized tensors for better performance.
|
||||
* The original constant initialized tensors will be removed to save memory.
|
||||
*/
|
||||
Status PrepackInitializedConstantTensors();
|
||||
Status PrepackConstantInitializedTensors(std::unordered_map<std::string, size_t>& constant_initializers_use_count);
|
||||
|
||||
SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);
|
||||
|
||||
|
|
@ -315,7 +319,8 @@ class SessionState {
|
|||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers);
|
||||
bool remove_initializers,
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
Status GeneratePatternGroupCache(
|
||||
|
|
@ -421,9 +426,9 @@ class SessionState {
|
|||
std::map<std::vector<int>, std::unordered_set<NodeIndex>> to_be_executed_nodes_;
|
||||
#endif
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
SessionState* parent_ = nullptr;
|
||||
//Assign each graph in each session an unique id.
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
int graph_id_ = 0;
|
||||
int next_graph_id_ = 1;
|
||||
|
||||
|
|
|
|||
|
|
@ -189,20 +189,7 @@ class PrePackingTestOpKernel : public OpKernel {
|
|||
}
|
||||
};
|
||||
|
||||
class SessionStatePrepackingTest : public testing::TestWithParam<bool> {};
|
||||
TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
||||
OrtThreadPoolParams to;
|
||||
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
|
||||
ONNX_OPERATOR_SCHEMA(PrePackingTest)
|
||||
.SetDoc("Faking Node for PrePacking")
|
||||
.Input(0, "Input_0", "input 0", "tensor(float)")
|
||||
.Input(1, "Input_1", "input 1", "tensor(float)")
|
||||
.Output(0, "output_0", "docstr for output_0.", "tensor(float)");
|
||||
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
// construct graph
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
static void CreateSimpleGraph(Graph& graph) {
|
||||
// node creation and placement
|
||||
TypeProto type;
|
||||
type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -218,8 +205,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
onnxruntime::NodeArg output_arg("node_0_output_0", &type);
|
||||
outputs.push_back(&output_arg);
|
||||
|
||||
onnxruntime::Node& node = graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
|
||||
|
||||
// add an initializer
|
||||
ONNX_NAMESPACE::TensorProto tensor;
|
||||
|
|
@ -231,6 +217,123 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
static const ONNX_NAMESPACE::GraphProto CreateSubgraph(bool then_branch) {
|
||||
Model model(then_branch ? "If_then" : "If_else", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
const std::string suffix = then_branch ? "0" : "1";
|
||||
|
||||
// graph input has to have type and rank even though it's an outer scope value.
|
||||
TypeProto type_float;
|
||||
type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
// outer scope values
|
||||
auto& if_shared = graph.GetOrCreateNodeArg("if_shared", &type_float);
|
||||
auto& if_input = graph.GetOrCreateNodeArg("if_input_" + suffix, &type_float);
|
||||
|
||||
// add so that we don't end up with it being considered a graph input
|
||||
graph.AddOuterScopeNodeArg("if_shared");
|
||||
graph.AddOuterScopeNodeArg("if_input_" + suffix);
|
||||
|
||||
auto& if_out = graph.GetOrCreateNodeArg("if_output_" + suffix, &type_float);
|
||||
|
||||
inputs = {&if_shared, &if_input};
|
||||
outputs = {&if_out};
|
||||
|
||||
graph.AddNode("if_node_" + suffix, "PrePackingTest", "if node " + suffix, inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
auto& proto = graph.ToGraphProto();
|
||||
|
||||
return proto;
|
||||
}
|
||||
|
||||
static void CreateGraphWithSubgraph(Graph& graph) {
|
||||
TypeProto type_float;
|
||||
type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
{
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
onnxruntime::NodeArg input_0_arg("if_input_0", &type_float);
|
||||
onnxruntime::NodeArg input_1_arg("if_input_1", &type_float);
|
||||
inputs.push_back(&input_0_arg);
|
||||
inputs.push_back(&input_1_arg);
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
onnxruntime::NodeArg output_arg("node_0_output_0", &type_float);
|
||||
outputs.push_back(&output_arg);
|
||||
|
||||
graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
|
||||
}
|
||||
|
||||
{
|
||||
TypeProto type_bool;
|
||||
type_bool.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
type_bool.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
onnxruntime::NodeArg bool_arg("bool_arg", &type_bool);
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
onnxruntime::NodeArg output_arg("output_arg", &type_float);
|
||||
outputs.push_back(&output_arg);
|
||||
|
||||
auto& if_node = graph.AddNode("if", "If", "If node", {&bool_arg}, outputs);
|
||||
|
||||
auto then_proto = CreateSubgraph(true);
|
||||
auto else_proto = CreateSubgraph(false);
|
||||
if_node.AddAttribute("then_branch", then_proto);
|
||||
if_node.AddAttribute("else_branch", else_proto);
|
||||
}
|
||||
|
||||
// add an initializer
|
||||
ONNX_NAMESPACE::TensorProto tensor;
|
||||
tensor.add_dims(1);
|
||||
tensor.add_float_data(1.0f);
|
||||
tensor.set_data_type(TensorProto_DataType_FLOAT);
|
||||
tensor.set_name("if_shared");
|
||||
graph.AddInitializedTensor(tensor);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
static void PlaceAllNodesToCPUEP(Graph& graph) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
if (node.ContainsSubgraph()) {
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
Graph* subgraph = entry.second;
|
||||
PlaceAllNodesToCPUEP(*subgraph);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PrepackingTestParam {
|
||||
bool test_subgraph;
|
||||
bool test_prepacking;
|
||||
};
|
||||
|
||||
class SessionStatePrepackingTest : public testing::TestWithParam<PrepackingTestParam> {};
|
||||
TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
||||
PrepackingTestParam test_param = GetParam();
|
||||
|
||||
OrtThreadPoolParams to;
|
||||
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
|
||||
ONNX_OPERATOR_SCHEMA(PrePackingTest)
|
||||
.SetDoc("Faking Node for PrePacking")
|
||||
.Input(0, "Input_0", "input 0", "tensor(float)")
|
||||
.Input(1, "Input_1", "input 1", "tensor(float)")
|
||||
.Output(0, "output_0", "docstr for output_0.", "tensor(float)");
|
||||
|
||||
ExecutionProviders execution_providers;
|
||||
auto cpu_execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
|
||||
|
|
@ -238,7 +341,21 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
|
||||
DataTransferManager dtm;
|
||||
profiling::Profiler profiler;
|
||||
SessionState session_state(graph,
|
||||
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = 11;
|
||||
Model model("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
|
||||
// onnxruntime::Model model("graph_main", false, DefaultLoggingManager().DefaultLogger());
|
||||
if (test_param.test_subgraph) {
|
||||
CreateGraphWithSubgraph(model.MainGraph());
|
||||
} else {
|
||||
CreateSimpleGraph(model.MainGraph());
|
||||
}
|
||||
|
||||
SessionState session_state(model.MainGraph(),
|
||||
execution_providers,
|
||||
true, /*enable_mem_pattern*/
|
||||
tp.get(),
|
||||
|
|
@ -248,7 +365,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
profiler);
|
||||
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
Status status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
|
||||
auto kernel_def = KernelDefBuilder().SetName("PrePackingTest").Provider(kCpuExecutionProvider).SinceVersion(1).Build();
|
||||
|
|
@ -257,19 +374,25 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
[](const OpKernelInfo& info) -> OpKernel* { return new PrePackingTestOpKernel(info); })));
|
||||
kernel_registry_manager.RegisterKernelRegistry(kernel_registry);
|
||||
|
||||
PlaceAllNodesToCPUEP(model.MainGraph());
|
||||
|
||||
SessionOptions sess_options;
|
||||
bool use_prepacking = GetParam();
|
||||
sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = use_prepacking ? "0" : "1";
|
||||
sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1";
|
||||
ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string<PATH_CHAR_TYPE>(),
|
||||
kernel_registry_manager,
|
||||
sess_options));
|
||||
|
||||
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
|
||||
// check prepacking
|
||||
ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1));
|
||||
ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(true, false));
|
||||
INSTANTIATE_TEST_SUITE_P(SessionStateTests,
|
||||
SessionStatePrepackingTest,
|
||||
testing::Values(PrepackingTestParam{false, false},
|
||||
PrepackingTestParam{false, true},
|
||||
PrepackingTestParam{true, false},
|
||||
PrepackingTestParam{true, true}));
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue