From 5d3992f999ec8bc2d30ddf9ba8686b3054bce595 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 24 Nov 2018 17:51:35 +1000 Subject: [PATCH] Handle the Scan subgraph producing outputs with a symbolic dimension. If the output has a symbolic dimension * Infer the shape if it is a loop state variable as we have the input value, and the shape from the subgraph output must match * Use a temporary MLValue for the first subgraph execution if it is a subgraph output with a symbolic dimension. * After the first execution make the overall output shape concrete and allocate the full output buffer. * Use slices of the full output buffer for all other subgraph executions to avoid copies. Add unit test to validate. --- .../core/providers/cpu/controlflow/scan.cc | 305 +++++++++++++++--- .../test/providers/cpu/controlflow/if_test.cc | 4 +- .../providers/cpu/controlflow/scan_test.cc | 57 +++- .../test/providers/provider_test_utils.cc | 4 +- .../test/providers/provider_test_utils.h | 21 +- 5 files changed, 325 insertions(+), 66 deletions(-) diff --git a/onnxruntime/core/providers/cpu/controlflow/scan.cc b/onnxruntime/core/providers/cpu/controlflow/scan.cc index cbd6044c82..0504fe5a2f 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan.cc @@ -118,6 +118,59 @@ class LoopStateVariable { MLValue b_; }; +/* +Class that co-ordinates writes to slices of the overall Scan output. +It will directly update the data returned by OpKernelContextInternal.Output(i). +*/ +class OutputIterator { + public: + static Status Create(OpKernelContextInternal& context, + int output_index, + bool is_loop_state_var, + TensorShape final_shape, + std::unique_ptr& iterator) { + iterator.reset(new OutputIterator(context, output_index, is_loop_state_var, final_shape)); + return iterator->Initialize(); + } + + MLValue& operator*(); + OutputIterator& operator++(); + + void ZeroOutCurrent() { + auto* tensor = (**this).GetMutable(); + memset(tensor->MutableDataRaw(), 0, tensor->Size()); + } + + private: + OutputIterator(OpKernelContextInternal& context, + int output_index, + bool is_loop_state_var, + TensorShape final_shape); + + Status Initialize(); + Status AllocateFinalBuffer(); + Status MakeConcrete(); + + OpKernelContextInternal& context_; + const int output_index_; + std::vector dims_; + TensorShapeProto per_iteration_shape_; + TensorShape final_shape_; + bool is_loop_state_var_; + int64_t num_iterations_; + int64_t cur_iteration_; + + bool is_concrete_shape_; + std::vector::Iterator> slicer_iterators_; + std::vector::Iterator>::iterator cur_slicer_iterator_; + + // if shape is not concrete we need the first output to know the missing dimension before + // we can allocate final_output_mlvalue_ and use the slicers. + MLValue first_output_; + + MLValue* final_output_mlvalue_; +}; + class ScanImpl { public: ScanImpl(OpKernelContextInternal& context, @@ -135,10 +188,10 @@ class ScanImpl { private: // validate inputs and setup batch size and max sequence length. Status ValidateInput(); - Status ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim, + Status ValidateSubgraphInput(int start_input, int end_input, bool is_loop_state_var, const std::vector& graph_inputs); - Status AllocateOutput(int index, bool has_sequence_len); + Status AllocateOutput(int index, bool is_loop_state_var); Status AllocateOutputTensors(); Status CreateLoopStateVariables(std::vector>& loop_state_variables); @@ -147,7 +200,6 @@ class ScanImpl { Status IterateSequence(std::vector& loop_state_variables, ConstTensorSlicerIterators& scan_input_stream_iterators, - MutableTensorSlicerIterators& scan_output_stream_iterators, int64_t seq_length); OpKernelContextInternal& context_; @@ -166,6 +218,7 @@ class ScanImpl { std::vector sequence_lens_; std::vector subgraph_output_names_; + std::vector> output_iterators_; std::unordered_map implicit_inputs_; }; @@ -249,6 +302,149 @@ void LoopStateVariable::Next() { ++iteration_num_; } +static Status MakeShapeConcrete(const TensorShape& per_iteration_shape, TensorShape& final_shape) { + auto num_dims_per_iteration = per_iteration_shape.NumDimensions(); + auto final_shape_offset = final_shape.NumDimensions() - num_dims_per_iteration; + for (size_t i = 0; i < num_dims_per_iteration; ++i) { + auto existing_value = final_shape[i + final_shape_offset]; + if (existing_value == -1) { + final_shape[i + final_shape_offset] = per_iteration_shape[i]; + } else { + if (existing_value != per_iteration_shape[i]) { + return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Mismatch between expected shape and shape from first output", + final_shape, " is not compatible with ", per_iteration_shape); + } + } + } + + return Status::OK(); +} + +OutputIterator::OutputIterator(OpKernelContextInternal& context, + int output_index, + bool is_loop_state_var, + TensorShape final_shape) + : context_{context}, + output_index_{output_index}, + is_loop_state_var_{is_loop_state_var}, + final_shape_{final_shape}, + cur_iteration_{0} { + is_concrete_shape_ = final_shape_.Size() >= 0; + + // there are one or two dimensions being iterated depending on whether it's a loop state variable or scan input. + auto num_iteration_dims = is_loop_state_var_ ? 1 : 2; + num_iterations_ = final_shape_.Slice(0, num_iteration_dims).Size(); +} + +Status OutputIterator::Initialize() { + Status status = Status::OK(); + + if (is_loop_state_var_ && !is_concrete_shape_) { + // copy the shape from the input initial value which will have a concrete shape. + auto* input = context_.Input(output_index_ + 1); // +1 to skip the sequence_len input + status = MakeShapeConcrete(input->Shape(), final_shape_); + ONNXRUNTIME_RETURN_IF_ERROR(status); + + is_concrete_shape_ = true; + } + + if (is_concrete_shape_) { + status = AllocateFinalBuffer(); + ONNXRUNTIME_RETURN_IF_ERROR(status); + } else { + // use first_output_ + } + + return Status::OK(); +} + +Status OutputIterator::AllocateFinalBuffer() { + // make sure a single buffer for the full output is created upfront. + // we slice this into per-iteration pieces in Execute using MLValueTensorSlicer. + auto* tensor = context_.Output(output_index_, final_shape_); + + if (!tensor) + return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for output #", output_index_); + + // get the output tensor we just created as an MLValue + final_output_mlvalue_ = context_.GetOutputMLValue(output_index_); + + if (is_loop_state_var_) { + // only one entry is required as we slice on a single dimension + slicer_iterators_.push_back(MLValueTensorSlicer::Create(*final_output_mlvalue_).begin()); + } else { + auto batch_size = final_shape_[0]; + for (int i = 0; i < batch_size; ++i) { + // the slicer handles the sequence dimension (dim 1) so create an entry for each batch + slicer_iterators_.push_back(MLValueTensorSlicer::Create(*final_output_mlvalue_, 1, i).begin()); + } + } + + cur_slicer_iterator_ = slicer_iterators_.begin(); + + return Status::OK(); +} + +Status OutputIterator::MakeConcrete() { + ONNXRUNTIME_ENFORCE(first_output_.IsAllocated(), "First usage of OutputIterator did not result in any output."); + Status status = Status::OK(); + + auto& tensor = first_output_.Get(); + auto& tensor_shape = tensor.Shape(); + + // update the final shape + status = MakeShapeConcrete(tensor_shape, final_shape_); + ONNXRUNTIME_RETURN_IF_ERROR(status); + + is_concrete_shape_ = true; + status = AllocateFinalBuffer(); + ONNXRUNTIME_RETURN_IF_ERROR(status); + + // copy first output to final buffer + auto input_span = gsl::make_span(static_cast(tensor.DataRaw()), tensor.Size()); + + auto output = (**this).GetMutable(); + auto output_span = gsl::make_span(static_cast(output->MutableDataRaw()), output->Size()); + + gsl::copy(input_span, output_span); + + // release the MLValue we used for the first output + first_output_ = {}; + + return status; +} + +MLValue& OutputIterator::operator*() { + ONNXRUNTIME_ENFORCE(cur_iteration_ < num_iterations_); + + if (is_concrete_shape_) + return **cur_slicer_iterator_; + else + return first_output_; +} + +OutputIterator& OutputIterator::operator++() { + if (cur_iteration_ < num_iterations_) { + if (!is_concrete_shape_) { + // we should have an output now, so convert to using the overall output buffer and slicers + auto status = MakeConcrete(); + ONNXRUNTIME_ENFORCE(status.IsOK(), status.ErrorMessage()); + } + + ++cur_iteration_; + + // if not a loop state var, see if we just finished the current sequence (dim 1) + if (!is_loop_state_var_ && cur_iteration_ % final_shape_[1] == 0) { + ++cur_slicer_iterator_; + } else { + ++(*cur_slicer_iterator_); + } + } + + return *this; +} + ScanImpl::ScanImpl(OpKernelContextInternal& context, const SessionState& session_state, int64_t num_scan_inputs, @@ -258,7 +454,7 @@ ScanImpl::ScanImpl(OpKernelContextInternal& context, subgraph_{*session_state.GetGraphViewer()}, directions_{directions}, implicit_inputs_{context_.GetImplicitInputs()} { - //optional first input so may be nullptr + // optional first input so may be nullptr sequence_lens_tensor_ = context.Input(0); num_variadic_inputs_ = context_.NumVariadicInputs(1); @@ -271,12 +467,12 @@ Status ScanImpl::Initialize() { auto status = ValidateInput(); ONNXRUNTIME_RETURN_IF_ERROR(status); - auto& graph_outputs = subgraph_.GetOutputs(); - subgraph_output_names_.reserve(graph_outputs.size()); + auto& subgraph_outputs = subgraph_.GetOutputs(); + subgraph_output_names_.reserve(subgraph_outputs.size()); // save list of subgraph output names in their provided order to use when fetching the results // from each subgraph execution. the Scan outputs will match this order. - for (auto& output : graph_outputs) { + for (auto& output : subgraph_outputs) { subgraph_output_names_.push_back(output->Name()); } @@ -301,9 +497,10 @@ static const MLValue& GetSubgraphInputMLValue(const OpKernelContextInternal& con } // Validate that the subgraph input has valid shapes -Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim, +Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool is_loop_state_var, const std::vector& graph_inputs) { // first dim is batch size. optional sequence dim. dim/s for the data + bool has_seq_len_dim = !is_loop_state_var; auto min_dims_required = has_seq_len_dim ? 3 : 2; for (int i = start_input; i < end_input; ++i) { @@ -355,11 +552,11 @@ Status ScanImpl::ValidateInput() { } // process any loop state variables, which will set the batch size - auto status = ValidateSubgraphInput(0, num_loop_state_variables_, false, graph_inputs); + auto status = ValidateSubgraphInput(0, num_loop_state_variables_, true, graph_inputs); ONNXRUNTIME_RETURN_IF_ERROR(status); // process the scan inputs. sets/validates batch size and sequence length - status = ValidateSubgraphInput(num_loop_state_variables_, num_variadic_inputs_, true, graph_inputs); + status = ValidateSubgraphInput(num_loop_state_variables_, num_variadic_inputs_, false, graph_inputs); ONNXRUNTIME_RETURN_IF_ERROR(status); if (sequence_lens_tensor_ != nullptr) { @@ -386,11 +583,12 @@ Status ScanImpl::ValidateInput() { return Status::OK(); } -Status ScanImpl::AllocateOutput(int index, bool has_sequence_len_dimension) { +Status ScanImpl::AllocateOutput(int index, bool is_loop_state_var) { // use the shape from the subgraph output. we require this to be specified in the model or inferable. auto& graph_outputs = subgraph_.GetOutputs(); auto* graph_output = graph_outputs.at(index); auto* graph_output_shape = graph_output->Shape(); + if (!graph_output_shape) { return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Subgraph must have the shape set for all outputs but ", graph_output->Name(), " did not."); @@ -404,24 +602,16 @@ Status ScanImpl::AllocateOutput(int index, bool has_sequence_len_dimension) { scan_output_dims.push_back(batch_size_); - if (has_sequence_len_dimension) { + if (!is_loop_state_var) { scan_output_dims.push_back(max_sequence_len_); } scan_output_dims.insert(scan_output_dims.cend(), graph_output_dims.cbegin(), graph_output_dims.cend()); - // make sure a single buffer for the full output is created upfront. - // we slice this into per-iteration pieces in Execute using MLValueTensorSlicer. - auto* tensor = context_.Output(index, TensorShape(scan_output_dims)); + std::unique_ptr output_iter; + OutputIterator::Create(context_, index, is_loop_state_var, TensorShape(scan_output_dims), output_iter); - if (!tensor) - return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name()); - - // zero out the output so that any short sequences have deterministic values in unused slots. - // strictly speaking this isn't required, and alternatively we could fill with zeros when we - // encounter a short sequence and are creating output, but one memset is easy, involves - // less code complexity, and should be relatively cheap. - memset(tensor->MutableDataRaw(), 0, tensor->Size()); + output_iterators_.push_back(std::move(output_iter)); return Status::OK(); } @@ -435,17 +625,13 @@ Status ScanImpl::AllocateOutputTensors() { " outputs but Scan expects ", num_variadic_outputs_); } - // TODO: Need to handle shape/type inference for subgraphs. - // For now copy shape from subgraph output and expand based on batch size and sequence length - for (int i = 0; i < num_loop_state_variables_; ++i) { - const bool has_sequence_len_dimension = false; // loop state variables don't have a sequence_len dimension; - status = AllocateOutput(i, has_sequence_len_dimension); + status = AllocateOutput(i, true); ONNXRUNTIME_RETURN_IF_ERROR(status); } for (int i = num_loop_state_variables_, end = num_variadic_outputs_; i < end; ++i) { - status = AllocateOutput(i, true); + status = AllocateOutput(i, false); ONNXRUNTIME_RETURN_IF_ERROR(status); } @@ -461,9 +647,7 @@ Status ScanImpl::CreateLoopStateVariables(std::vector::Iterator> loop_state_input_iterators; - std::vector::Iterator> loop_state_output_iterators; loop_state_input_iterators.reserve(num_loop_state_variables_); - loop_state_output_iterators.reserve(num_loop_state_variables_); // create the input and output slice iterator for each loop state variable. for (int i = 0; i < num_loop_state_variables_; ++i) { @@ -473,7 +657,6 @@ Status ScanImpl::CreateLoopStateVariables(std::vector::Create(mlvalue).begin()); - loop_state_output_iterators.push_back(MLValueTensorSlicer::Create(*p_mlvalue).begin()); } batch_loop_state_variables.clear(); @@ -490,7 +673,7 @@ Status ScanImpl::CreateLoopStateVariables(std::vector::Iterator> scan_output_stream_iterators; - scan_output_stream_iterators.reserve(num_variadic_outputs_); - - for (int i = num_loop_state_variables_, end = num_variadic_outputs_; i < end; ++i) { - MLValue* p_mlvalue = context_.GetOutputMLValue(i); - ONNXRUNTIME_ENFORCE(p_mlvalue, "Output MLValue has not been created for output ", i); - - scan_output_stream_iterators.push_back(MLValueTensorSlicer::Create(*p_mlvalue, 1, b).begin()); - } - // Call the subgraph for each item in the sequence status = IterateSequence(batch_loop_state_variables[b], scan_input_stream_iterators, - scan_output_stream_iterators, sequence_lens_[b]); ONNXRUNTIME_RETURN_IF_ERROR(status); @@ -558,7 +729,6 @@ Status ScanImpl::Execute() { Status ScanImpl::IterateSequence(std::vector& loop_state_variables, ConstTensorSlicerIterators& scan_input_stream_iterators, - MutableTensorSlicerIterators& scan_output_stream_iterators, int64_t seq_length) { Status status = Status::OK(); auto& graph_inputs = subgraph_.GetInputs(); @@ -575,9 +745,8 @@ Status ScanImpl::IterateSequence(std::vector& loop_state_vari feeds[entry.first] = *entry.second; } - // as we fill all the outputs with 0 initially, just iterate seq_length not max_seq_length_ - // as we don't need to pad the output for a short sequence here. - for (int64_t seq_no = 0; seq_no < seq_length; ++seq_no) { + int64_t seq_no = 0; + for (; seq_no < seq_length; ++seq_no) { for (int input = 0; input < num_variadic_inputs_; ++input) { // the ordering of the Scan inputs should match the ordering of the subgraph inputs auto name = graph_inputs[input]->Name(); @@ -596,15 +765,24 @@ Status ScanImpl::IterateSequence(std::vector& loop_state_vari fetches.clear(); + bool copy_fetch_to_iter = false; + for (int output = 0, end = num_variadic_outputs_; output < end; ++output) { if (output < num_loop_state_variables_) { // add loop state variable output fetches.push_back(loop_state_variables[output].Output()); } else { - // add sliced output - auto& iterator = scan_output_stream_iterators[output - num_loop_state_variables_]; - fetches.push_back(*iterator); - ++iterator; + // add MLValue from sliced output + auto& iterator = *output_iterators_[output]; + auto& mlvalue = *iterator; + fetches.push_back(mlvalue); + + // If there is a dynamic shape in an output we need to copy it back to the OutputIterator + // so it can setup the overall output and avoid copies for all other output values. + // The mlvalue in the iterator will point to data once we have the overall output initialized. + // Check current value as we don't want to unset copy_fetch_to_iter if it is true. + if (!copy_fetch_to_iter) + copy_fetch_to_iter = (seq_no == 0) && (mlvalue.IsAllocated() == false); } } @@ -620,6 +798,27 @@ Status ScanImpl::IterateSequence(std::vector& loop_state_vari // cycle the LoopStateVariable input/output in preparation for the next iteration std::for_each(loop_state_variables.begin(), loop_state_variables.end(), [](LoopStateVariable& v) { v.Next(); }); + + // and move the output iterators. + for (int output = num_loop_state_variables_; output < num_variadic_outputs_; ++output) { + auto& iterator = *output_iterators_[output]; + + // copy the data from fetches to the iterator so it can setup the overall output + if (copy_fetch_to_iter && (*iterator).IsAllocated() == false) { + *iterator = fetches[output]; + } + + ++iterator; + } + } + + // zero out any remaining values in the sequence + for (; seq_length < max_sequence_len_; ++seq_length) { + for (int output = num_loop_state_variables_; output < num_variadic_outputs_; ++output) { + auto& iterator = *output_iterators_[output]; + iterator.ZeroOutCurrent(); + ++iterator; + } } return status; diff --git a/onnxruntime/test/providers/cpu/controlflow/if_test.cc b/onnxruntime/test/providers/cpu/controlflow/if_test.cc index b53b451a0d..858d9c550f 100644 --- a/onnxruntime/test/providers/cpu/controlflow/if_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/if_test.cc @@ -16,7 +16,7 @@ namespace test { struct RunOptions { bool include_dim_values_in_main_graph = false; - bool symbolic_dim_values_in_main_graph = false; + int symbolic_dim_value_in_main_graph = -1; bool include_dim_values_in_subgraph = true; }; @@ -181,7 +181,7 @@ void RunTest(bool condition_value, IfOpTester test{options}; test.AddShapeToTensorData(options.include_dim_values_in_main_graph, - options.symbolic_dim_values_in_main_graph); + options.symbolic_dim_value_in_main_graph); // add the main graph inputs and outputs. // we will handle the 'If' inputs in the AddNodes override, and as 'If' is the last node diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 965b7295d8..856bf4e145 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -261,8 +261,6 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen ScanOpTester test; - test.AddShapeToTensorData(options.include_dim_values_in_main_graph); - test.AddAttribute("body", proto); test.AddAttribute("num_scan_inputs", 2); @@ -277,6 +275,8 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen test.AddInput("sequence_lens", sequence_lens_dims, *sequence_lens); } + test.AddShapeToTensorData(options.include_dim_values_in_main_graph); + test.AddInput("scan_loop_state_in_0", {batch_size, 1}, loop_state_in_0); std::vector input_shape{batch_size, max_sequence_len, input_size}; @@ -665,5 +665,58 @@ TEST(Scan, MixedTypeInputs) { test.Run(); } +TEST(Scan, UnknownDimInSubgraphOutput) { + Model model("ScanBody"); + auto& graph = model.MainGraph(); + + TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("param"); + TypeProto int_tensor; + int_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + int_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("param"); + + auto& state_in_1 = graph.GetOrCreateNodeArg("state_in_1", &float_tensor); + auto& scan_in_1 = graph.GetOrCreateNodeArg("scan_in_1", &float_tensor); + + auto& state_out_1 = graph.GetOrCreateNodeArg("state_out_1", &float_tensor); + auto& scan_out_1 = graph.GetOrCreateNodeArg("scan_out_1", &float_tensor); + + graph.AddNode("node1", "Identity", "Copy state_in_1 to scan_out_1", {&state_in_1}, {&scan_out_1}); + graph.AddNode("node2", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1}); + + graph.SetInputOrder({&state_in_1, &scan_in_1}); + graph.SetOutputOrder({&state_out_1, &scan_out_1}); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + auto& scan_body = graph.ToGraphProto(); + + // Construct and run scan test + ScanOpTester test; + + int64_t batch_size = 1, sequence_len = 3, input_size = 1; + std::vector seq_shape{batch_size, sequence_len, input_size}; + std::vector state_shape{batch_size, input_size}; + + test.AddAttribute("body", scan_body); + test.AddAttribute("num_scan_inputs", 1); + + // we add a symbolic dimension to bot the initial state and the scan input so we test the path that handles loop + // state variables (prior to execution) and the path that handles subgraph outputs (post first execution). + // Note that we cross the values over in the subgraph, so the symbolic dimension in + // initial_state_1 affects scan_out_1, and the symbolic dimension in scan_input_1 affects state_out_1. + test.AddMissingOptionalInput(); + test.AddShapeToTensorData(true, 1); // add shape and symbolic dim in dim 1 for initial_state_1 + test.AddInput("initial_state_1", state_shape, {0.0}); + test.AddShapeToTensorData(true, 2); // add shape and symbolic dim in dim 2 for scan_input_1 + test.AddInput("scan_input_1", seq_shape, {1.0, 2.0, 3.0}); + + test.AddOutput("final_state_1", state_shape, {3.0}); + test.AddOutput("scan_output_1", seq_shape, {0.0, 1.0, 2.0}); + + test.Run(); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 2c118f1372..7815ab4d37 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -407,7 +407,9 @@ void OpTester::Run(ExpectResult expect_result, const auto& expected_shape = expected_data.data_.Get().Shape(); EXPECT_TRUE(inferred_dims.size() == expected_shape.NumDimensions()); for (int d = 0; d < inferred_dims.size(); ++d) { - EXPECT_EQ(expected_shape[d], inferred_dims[d]); + // check equal unless the input involved a symbolic dimension + if (inferred_dims[d] != -1) + EXPECT_EQ(expected_shape[d], inferred_dims[d]) << "Output idx = " << idx << " dim = " << d; } } Check(expected_data, mlvalue.Get(), provider_type); diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index 0ce06aea34..7c6abcebcc 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -91,7 +91,11 @@ struct TTypeProto : ONNX_NAMESPACE::TypeProto { if (shape) { auto mutable_shape = mutable_tensor_type()->mutable_shape(); for (auto i : *shape) { - mutable_shape->add_dim()->set_dim_value(i); + auto* mutable_dim = mutable_shape->add_dim(); + if (i != -1) + mutable_dim->set_dim_value(i); + else + mutable_dim->set_dim_param("symbolic"); } } } @@ -145,10 +149,11 @@ class OpTester { // Set whether the NodeArg created by AddInput/AddOutput should include shape information // for Tensor types. If not added, shape inferencing should resolve. If added, shape inferencing - // should validate. Default is to not add. - OpTester& AddShapeToTensorData(bool add_shape = true, bool add_symbolic_dim = false) { + // should validate. Default is to not add. + // Additionally a symbolic dimension will be added if symbolic_dim matches a dimension in the input. + OpTester& AddShapeToTensorData(bool add_shape = true, int symbolic_dim = -1) { add_shape_to_tensor_data_ = add_shape; - add_symbolic_dim_to_tensor_data_ = add_symbolic_dim; + add_symbolic_dim_to_tensor_data_ = symbolic_dim; return *this; } @@ -268,7 +273,7 @@ class OpTester { ONNXRUNTIME_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", shape.Size()); - auto allocator = ::onnxruntime::test::AllocatorManager::Instance().GetAllocator(CPU); + auto allocator = test::AllocatorManager::Instance().GetAllocator(CPU); auto size_in_bytes = values_count * sizeof(T); void* buffer = allocator->Alloc(size_in_bytes); auto p_tensor = std::make_unique(DataTypeImpl::GetType(), @@ -283,8 +288,8 @@ class OpTester { } std::vector dims_for_proto{dims}; - if (add_symbolic_dim_to_tensor_data_ && !dims.empty()) { - dims_for_proto[0] = -1; + if (add_symbolic_dim_to_tensor_data_ >= 0 && dims.size() > add_symbolic_dim_to_tensor_data_) { + dims_for_proto[add_symbolic_dim_to_tensor_data_] = -1; } TTypeProto type_proto(add_shape_to_tensor_data_ ? &dims_for_proto : nullptr); @@ -302,7 +307,7 @@ class OpTester { const char* domain_; int opset_version_; bool add_shape_to_tensor_data_ = true; - bool add_symbolic_dim_to_tensor_data_ = false; + int add_symbolic_dim_to_tensor_data_ = -1; std::vector input_data_; std::vector output_data_; std::vector initializer_index_;