mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
subgraph type override handling and unit test (#3560)
* unit test for subgraph type override * unit test - re-wire input properly to subgraph * update args Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
2cb8cb816f
commit
ca1bbff5d4
6 changed files with 184 additions and 35 deletions
|
|
@ -987,7 +987,8 @@ class Graph {
|
|||
// perform type and shape inferencing on the subgraph and Resolve to validate
|
||||
static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
||||
const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
|
||||
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types);
|
||||
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
|
||||
const Graph::ResolveOptions& options);
|
||||
|
||||
// Apply type-inference and type-checking to all inputs and initializers:
|
||||
common::Status TypeCheckInputsAndInitializers();
|
||||
|
|
|
|||
|
|
@ -67,15 +67,17 @@ class NodeArg {
|
|||
/** Validate and merge type [and shape] info from input_type.
|
||||
@param strict If true, the shape update will fail if there are incompatible values.
|
||||
If false, will be lenient and merge only shape info that can be validly processed.
|
||||
@param override_types If true, resolve the two inputs or two outputs type when different
|
||||
@returns Success unless there is existing type or shape info that can't be successfully updated. */
|
||||
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger);
|
||||
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);
|
||||
|
||||
/** Validate and merge type [and shape] info from node_arg.
|
||||
@param strict If true, the shape update will fail if there are incompatible values.
|
||||
If false, will be lenient and merge only shape info that can be validly processed.
|
||||
@param override_types If true, resolve the two inputs or two outputs type when different
|
||||
@returns Success unless there is existing type or shape info that can't be successfully updated. */
|
||||
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger);
|
||||
|
||||
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
|
||||
|
||||
/** Gets this NodeArg as a ValueInfoProto. */
|
||||
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -208,7 +208,8 @@ void NodeArg::ClearShape() {
|
|||
}
|
||||
}
|
||||
|
||||
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) {
|
||||
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types,
|
||||
const logging::Logger& logger) {
|
||||
if (!utils::HasType(node_arg_info_)) {
|
||||
*node_arg_info_.mutable_type() = input_type;
|
||||
type_ = DataTypeUtils::ToType(node_arg_info_.type());
|
||||
|
|
@ -229,10 +230,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
|
|||
const auto& input_tensor_elem_type = input_tensor_type.elem_type();
|
||||
const auto& current_tensor_elem_type = current_type.tensor_type().elem_type();
|
||||
|
||||
if (input_tensor_elem_type != current_tensor_elem_type)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
|
||||
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
|
||||
static_cast<TensorProto_DataType>(current_tensor_elem_type));
|
||||
if (input_tensor_elem_type != current_tensor_elem_type) {
|
||||
if (override_types) {
|
||||
DataType inferred_type = DataTypeUtils::ToType(input_type);
|
||||
// The "SetType" call will override the shape information to empty.
|
||||
// If the original tensor has shape information, need to set it back.
|
||||
if (Shape()) {
|
||||
auto old_shape = *Shape();
|
||||
SetType(inferred_type);
|
||||
SetShape(old_shape);
|
||||
} else {
|
||||
SetType(inferred_type);
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
|
||||
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
|
||||
static_cast<TensorProto_DataType>(current_tensor_elem_type));
|
||||
}
|
||||
}
|
||||
|
||||
if (utils::HasShape(input_tensor_type)) {
|
||||
auto& current_tensor_type = *current_type.mutable_tensor_type();
|
||||
|
|
@ -249,11 +264,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
|
|||
const auto& input_tensor_type = input_type.sparse_tensor_type();
|
||||
const auto input_tensor_elem_type = input_tensor_type.elem_type();
|
||||
const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type();
|
||||
|
||||
if (input_tensor_elem_type != current_tensor_elem_type) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
|
||||
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
|
||||
static_cast<TensorProto_DataType>(current_tensor_elem_type));
|
||||
if (override_types) {
|
||||
DataType inferred_type = DataTypeUtils::ToType(input_type);
|
||||
if (Shape()) {
|
||||
auto old_shape = *Shape();
|
||||
SetType(inferred_type);
|
||||
SetShape(old_shape);
|
||||
} else {
|
||||
SetType(inferred_type);
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
|
||||
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
|
||||
static_cast<TensorProto_DataType>(current_tensor_elem_type));
|
||||
}
|
||||
}
|
||||
|
||||
if (utils::HasShape(input_tensor_type)) {
|
||||
auto& current_tensor_type = *current_type.mutable_sparse_tensor_type();
|
||||
if (utils::HasShape(current_tensor_type)) {
|
||||
|
|
@ -275,11 +303,12 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) {
|
||||
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types,
|
||||
const logging::Logger& logger) {
|
||||
auto status = Status::OK();
|
||||
|
||||
if (utils::HasType(node_arg.node_arg_info_))
|
||||
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger);
|
||||
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, override_types, logger);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
@ -771,7 +800,7 @@ Graph::Graph(const Model& owning_model,
|
|||
// so we prefer the shape from the initializer
|
||||
name_to_type_map[tensor.name()] = t;
|
||||
if (matching_graph_input != nullptr) {
|
||||
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger));
|
||||
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, false, logger));
|
||||
}
|
||||
} else {
|
||||
// v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these.
|
||||
|
|
@ -1398,12 +1427,12 @@ bool FullyDefinedType(const TypeProto& type_proto) {
|
|||
// parameters are the Graph instance for the subgraph, the input types from the control flow node that contains
|
||||
// the subgraph, and the vector to write the output from the inferencing.
|
||||
using SubgraphInferencingFunc =
|
||||
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&)>;
|
||||
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&, const Graph::ResolveOptions&)>;
|
||||
|
||||
class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
|
||||
public:
|
||||
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func)
|
||||
: node_(node), graph_(graph), inferencing_func_(inferencing_func) {
|
||||
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options)
|
||||
: node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) {
|
||||
}
|
||||
|
||||
// Perform inferencing on the graph contained in GraphInferencer.
|
||||
|
|
@ -1413,7 +1442,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
|
|||
const std::vector<const TensorProto*>& /*input_data*/) override {
|
||||
std::vector<const TypeProto*> output_types;
|
||||
|
||||
auto status = inferencing_func_(node_, graph_, input_types, output_types);
|
||||
auto status = inferencing_func_(node_, graph_, input_types, output_types, options_);
|
||||
|
||||
if (status != Status::OK()) {
|
||||
fail_type_inference("Graph attribute inferencing failed: ", status.ErrorMessage());
|
||||
|
|
@ -1426,6 +1455,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
|
|||
const Node& node_;
|
||||
Graph& graph_;
|
||||
SubgraphInferencingFunc& inferencing_func_;
|
||||
const Graph::ResolveOptions& options_;
|
||||
};
|
||||
|
||||
// An implementation of the InferenceContext interface required by operator-specific
|
||||
|
|
@ -1436,10 +1466,12 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
|
|||
public:
|
||||
InferenceContextImpl(Node& node,
|
||||
SubgraphInferencingFunc subgraph_inferencing_func,
|
||||
const Graph& graph) noexcept
|
||||
const Graph& graph,
|
||||
const Graph::ResolveOptions& options) noexcept
|
||||
: node_(node),
|
||||
subgraph_inferencing_func_(subgraph_inferencing_func),
|
||||
graph_(graph) {
|
||||
graph_(graph),
|
||||
options_(options) {
|
||||
node_output_types_.resize(node.OutputDefs().size());
|
||||
}
|
||||
|
||||
|
|
@ -1500,7 +1532,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
|
|||
auto* subgraph = node_.GetMutableGraphAttribute(attribute_name);
|
||||
|
||||
if (subgraph) {
|
||||
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_);
|
||||
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_, options_);
|
||||
graph_inferencer = inferencer.get();
|
||||
graph_inferencers_.push_back(std::move(inferencer));
|
||||
} else {
|
||||
|
|
@ -1518,11 +1550,13 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
|
|||
SubgraphInferencingFunc subgraph_inferencing_func_;
|
||||
std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_;
|
||||
const Graph& graph_;
|
||||
const Graph::ResolveOptions& options_;
|
||||
};
|
||||
|
||||
Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
||||
const std::vector<const TypeProto*>& input_types,
|
||||
std::vector<const TypeProto*>& output_types) {
|
||||
std::vector<const TypeProto*>& output_types,
|
||||
const Graph::ResolveOptions& options) {
|
||||
auto status = Status::OK();
|
||||
|
||||
output_types.clear();
|
||||
|
|
@ -1555,7 +1589,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
const auto& subgraph_input = *subgraph_inputs->at(i);
|
||||
|
||||
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
|
||||
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_);
|
||||
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, options.override_types, subgraph.logger_);
|
||||
if (!status.IsOK()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -1576,7 +1610,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
if (!subgraph_nodearg)
|
||||
continue;
|
||||
|
||||
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_);
|
||||
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, options.override_types, subgraph.logger_);
|
||||
if (!status.IsOK()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
|
||||
}
|
||||
|
|
@ -1588,8 +1622,6 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
|
||||
// now that we have handled the input types, do the type/shape inferencing for the subgraph
|
||||
// to flow the type/shape info through it
|
||||
// TODO: Handle override-type option correctly for subgraphs.
|
||||
Graph::ResolveOptions options;
|
||||
status = subgraph.PerformTypeAndShapeInferencing(options);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
|
|
@ -1695,7 +1727,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
|
|||
// Once that completes, the outputs from the node containing the subgraph will be updated, and the final values
|
||||
// returned here.
|
||||
SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes);
|
||||
InferenceContextImpl context(node, func, *this);
|
||||
InferenceContextImpl context(node, func, *this, options);
|
||||
|
||||
try {
|
||||
context.RunInferencing();
|
||||
|
|
|
|||
|
|
@ -574,6 +574,116 @@ TEST(Loop, InfiniteLoopTermination) {
|
|||
terminator_thread.join();
|
||||
}
|
||||
|
||||
// Add basic test to trigger types override logic in Graph::InferAndVerifySubgraphTypes as well as
|
||||
// type/shape inferencing for subgraph to flow the type/shape info through
|
||||
// subgraph.PerformTypeAndShapeInferencing(options).
|
||||
// In this test, main graph has original input/expected output defined as "double" where the subgraph as "float".
|
||||
// Expectation is types should get propagated properly in subgraph and yield correct output
|
||||
//
|
||||
// TODO - when the input/output type in main graph is float16, extra Cast nodes will be added and type input type
|
||||
// will be changed by InsertCastTransformer for graph execution thus causes type mismatch failure.
|
||||
// Need to investigate how InsertCastTransformer works in future.
|
||||
TEST(Loop, SubgraphTypeOverride) {
|
||||
auto create_subgraph = [](const RunOptions&) {
|
||||
Model model("Loop subgraph", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
/*
|
||||
Inputs: iter_num, cond_in, fake_in, loop carried state variables.
|
||||
|
||||
iter_num_in cond_in fake_in [outer_scope_0]
|
||||
(unused) | | |
|
||||
[Identity] [Identity] [Identity]
|
||||
| | |
|
||||
cond_out fake_out loop_var_0_out
|
||||
*/
|
||||
|
||||
// graph inputs types.
|
||||
TypeProto int64_scalar;
|
||||
int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
TypeProto bool_scalar;
|
||||
bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
|
||||
// graph inputs
|
||||
auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
|
||||
auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
|
||||
auto& fake_in = graph.GetOrCreateNodeArg("fake_in", &float_tensor);
|
||||
|
||||
// outer scope value. need type but not shape.
|
||||
auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor);
|
||||
|
||||
// add so that we don't end up with it being considered a graph input
|
||||
graph.AddOuterScopeNodeArg("outer_scope_0");
|
||||
|
||||
// graph outputs
|
||||
auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar);
|
||||
auto& fake_out = graph.GetOrCreateNodeArg("fake_out", &float_tensor);
|
||||
auto& loop_var_0_out = graph.GetOrCreateNodeArg("loop_var_0_out", &float_tensor);
|
||||
|
||||
// cond_in -> cond_out
|
||||
{
|
||||
inputs = {&cond_in};
|
||||
outputs = {&cond_out};
|
||||
|
||||
graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs);
|
||||
}
|
||||
|
||||
// fake_in -> fake_out
|
||||
{
|
||||
inputs = {&fake_in};
|
||||
outputs = {&fake_out};
|
||||
|
||||
graph.AddNode("fake_in_identity", "Identity", "Forward fake_in to fake_out", inputs, outputs);
|
||||
}
|
||||
|
||||
// outer_scope_0 -> loop_var_0_out
|
||||
{
|
||||
inputs = {&outer_scope_0};
|
||||
outputs = {&loop_var_0_out};
|
||||
|
||||
graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs);
|
||||
}
|
||||
|
||||
graph.SetInputs({&iter_num_in, &cond_in, &fake_in});
|
||||
graph.SetOutputs({&cond_out, &fake_out, &loop_var_0_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return graph.ToGraphProto();
|
||||
};
|
||||
|
||||
LoopOpTester test{{}, create_subgraph};
|
||||
|
||||
test.AddInput<int64_t>("M", {1}, {1});
|
||||
test.AddInput<bool>("cond", {1}, {true});
|
||||
test.AddInput<double>("fake", {1}, {0.f});
|
||||
test.AddInput<double>("outer_scope_0", {1}, {kOuterNodeAddValue});
|
||||
|
||||
test.AddOutput<double>("loop_fake_final", {1}, {0.f});
|
||||
test.AddOutput<double>("loop_var_0_final", {1, 1}, {kOuterNodeAddValue});
|
||||
test.AddOutput<int64_t>("outer_scope_0_out", {1}, {int64_t(kOuterNodeAddValue)});
|
||||
|
||||
OrtRunOptions session_run_options;
|
||||
session_run_options.run_tag = "Loop.SubgraphTypeOverride";
|
||||
|
||||
Graph::ResolveOptions options;
|
||||
options.override_types = true;
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider}, &session_run_options, nullptr,
|
||||
ExecutionMode::ORT_SEQUENTIAL, {}, options);
|
||||
}
|
||||
|
||||
// Regression test that a subgraph input overrides an outer scope value of the same name.
|
||||
// Replicate issue from https://github.com/onnx/onnx/issues/2082
|
||||
TEST(Loop, SubgraphInputShadowsOuterScopeValue) {
|
||||
|
|
|
|||
|
|
@ -626,14 +626,15 @@ void OpTester::Run(
|
|||
const RunOptions* run_options,
|
||||
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
|
||||
ExecutionMode execution_mode,
|
||||
const CustomOutputVerifierFn& custom_output_verifier) {
|
||||
const CustomOutputVerifierFn& custom_output_verifier,
|
||||
const Graph::ResolveOptions& options) {
|
||||
SessionOptions so;
|
||||
so.session_logid = op_;
|
||||
so.session_log_verbosity_level = 1;
|
||||
so.execution_mode = execution_mode;
|
||||
so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off
|
||||
Run(so, expect_result, expected_failure_string, excluded_provider_types,
|
||||
run_options, execution_providers, custom_output_verifier);
|
||||
run_options, execution_providers, custom_output_verifier, options);
|
||||
}
|
||||
|
||||
void OpTester::Run(
|
||||
|
|
@ -643,7 +644,8 @@ void OpTester::Run(
|
|||
const std::unordered_set<std::string>& excluded_provider_types,
|
||||
const RunOptions* run_options,
|
||||
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
|
||||
const CustomOutputVerifierFn& custom_output_verifier) {
|
||||
const CustomOutputVerifierFn& custom_output_verifier,
|
||||
const Graph::ResolveOptions& options) {
|
||||
std::string cur_provider = "not set";
|
||||
try {
|
||||
#ifndef NDEBUG
|
||||
|
|
@ -660,12 +662,12 @@ void OpTester::Run(
|
|||
expect_result == ExpectResult::kExpectFailure) {
|
||||
// capture possible exceptions from shape inference for invalid testcase
|
||||
try {
|
||||
status = graph.Resolve();
|
||||
status = graph.Resolve(options);
|
||||
} catch (const std::exception& ex) {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
|
||||
}
|
||||
} else {
|
||||
status = graph.Resolve();
|
||||
status = graph.Resolve(options);
|
||||
}
|
||||
|
||||
if (!status.IsOK()) {
|
||||
|
|
|
|||
|
|
@ -413,7 +413,8 @@ class OpTester {
|
|||
const RunOptions* run_options = nullptr,
|
||||
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
|
||||
ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL,
|
||||
const CustomOutputVerifierFn& custom_output_verifier = {});
|
||||
const CustomOutputVerifierFn& custom_output_verifier = {},
|
||||
const Graph::ResolveOptions& resolve_options = {});
|
||||
|
||||
void Run(SessionOptions session_options,
|
||||
ExpectResult expect_result = ExpectResult::kExpectSuccess,
|
||||
|
|
@ -421,7 +422,8 @@ class OpTester {
|
|||
const std::unordered_set<std::string>& excluded_provider_types = {},
|
||||
const RunOptions* run_options = nullptr,
|
||||
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
|
||||
const CustomOutputVerifierFn& custom_output_verifier = {});
|
||||
const CustomOutputVerifierFn& custom_output_verifier = {},
|
||||
const Graph::ResolveOptions& resolve_options = {});
|
||||
|
||||
std::vector<MLValue> GetFetches() { return fetches_; }
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue