mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Support scalars (zero dimensions) in Scan by allowing the parameters to Scan to have no dimension for the input data.
This commit is contained in:
parent
c7513e676f
commit
03d7d25989
3 changed files with 51 additions and 26 deletions
|
|
@@ -15,8 +15,8 @@ MLValueTensorSlicer<T> MLValueTensorSlicer<T>::Create(T& mlvalue, int64_t slice_
|
|||
ONNXRUNTIME_ENFORCE(mlvalue.IsAllocated(), "MLValue has not been allocated so can't be sliced.");
|
||||
|
||||
auto& tensor_shape{mlvalue.template Get<Tensor>().Shape()};
|
||||
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) > slice_dimension,
|
||||
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
|
||||
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) >= slice_dimension,
|
||||
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
|
||||
|
||||
auto dim0_size = tensor_shape[0];
|
||||
ONNXRUNTIME_ENFORCE(dim0_offset < dim0_size, "Invalid dim0_offset of ", dim0_offset, ". Dimension 0 is ", dim0_size);
|
||||
|
|
|
|||
|
|
@@ -303,8 +303,9 @@ static const MLValue& GetSubgraphInputMLValue(const OpKernelContextInternal& con
|
|||
// Validate that the subgraph input has valid shapes
|
||||
Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim,
|
||||
const std::vector<const NodeArg*>& graph_inputs) {
|
||||
// first dim is batch size. optional sequence dim. dim/s for the data
|
||||
auto min_dims_required = has_seq_len_dim ? 3 : 2;
|
||||
// first dim is batch size. optional sequence dim. dim/s for the data.
|
||||
// if there is no dim for the data treat it as a scalar.
|
||||
auto min_dims_required = has_seq_len_dim ? 2 : 1;
|
||||
|
||||
for (int i = start_input; i < end_input; ++i) {
|
||||
auto& input_tensor = GetSubgraphInputTensor(context_, i);
|
||||
|
|
|
|||
|
|
@@ -17,6 +17,7 @@ struct RunOptions {
|
|||
bool include_dim_values_in_subgraph = true;
|
||||
bool include_types_in_subgraph = true;
|
||||
bool include_outer_scope_add = false;
|
||||
bool scalar_loop_state_value = false;
|
||||
bool add_bad_shape = false;
|
||||
};
|
||||
|
||||
|
|
@@ -37,13 +38,13 @@ class ScanOpTester : public OpTester {
|
|||
// add outer_scope_0 node. push the value through an extra Identity node as a Constant gets lifted into an
|
||||
// initializer which results in different treatment by the allocation planner
|
||||
{
|
||||
TypeProto float_scalar;
|
||||
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
TypeProto float_single_value;
|
||||
float_single_value.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
auto mutable_dim = float_single_value.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
mutable_dim->set_dim_value(1);
|
||||
|
||||
{
|
||||
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_scalar);
|
||||
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_single_value);
|
||||
auto* constant = graph.AddNode("outer_scope_constant", "Constant", "Constant with value kOuterNodeAddValue",
|
||||
{}, {&outer_scope_constant});
|
||||
|
||||
|
|
@@ -54,7 +55,7 @@ class ScanOpTester : public OpTester {
|
|||
|
||||
constant->AddAttribute("value", value_tensor);
|
||||
|
||||
auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_scalar);
|
||||
auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_single_value);
|
||||
graph.AddNode("outer_scope_id", "Identity", "Identity for outer_scope_0",
|
||||
{&outer_scope_constant}, {&outer_scope_node_arg});
|
||||
}
|
||||
|
|
@@ -66,7 +67,7 @@ class ScanOpTester : public OpTester {
|
|||
};
|
||||
|
||||
static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string& failure_message) {
|
||||
bool include_shapes = options.include_dim_values_in_subgraph;
|
||||
bool include_dim_values = options.include_dim_values_in_subgraph;
|
||||
bool include_types = options.include_types_in_subgraph;
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
|
|
@@ -94,21 +95,27 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
|
|||
inputs = {};
|
||||
outputs = {};
|
||||
|
||||
TypeProto float_scalar;
|
||||
TypeProto float_input;
|
||||
// inputs must have type information and a rank
|
||||
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
if (include_shapes)
|
||||
mutable_dim->set_dim_value(1);
|
||||
float_input.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
auto mutable_shape = float_input.mutable_tensor_type()->mutable_shape();
|
||||
if (options.scalar_loop_state_value) {
|
||||
// no dims
|
||||
} else {
|
||||
auto mutable_dim = mutable_shape->add_dim(); // set rank
|
||||
if (include_dim_values)
|
||||
mutable_dim->set_dim_value(1);
|
||||
}
|
||||
|
||||
{
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_scalar);
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_input);
|
||||
outputs.push_back(&output_arg);
|
||||
|
||||
auto* constant = graph.AddNode("constant", "Constant", "Constant with value 1", inputs, outputs);
|
||||
|
||||
TensorProto value_tensor;
|
||||
value_tensor.add_dims(1);
|
||||
if (!options.scalar_loop_state_value)
|
||||
value_tensor.add_dims(1);
|
||||
value_tensor.add_float_data(1.f);
|
||||
value_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
|
||||
|
||||
|
|
@@ -118,7 +125,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
|
|||
inputs = outputs; // start with output from Constant node
|
||||
outputs = {};
|
||||
|
||||
auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_scalar);
|
||||
auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_input);
|
||||
inputs.push_back(&input_arg);
|
||||
|
||||
TypeProto loop_state_output_tensor;
|
||||
|
|
@@ -128,15 +135,17 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
|
|||
// it has to come from here.
|
||||
bool type_and_shape_required = options.include_dim_values_in_main_graph == false;
|
||||
|
||||
if (include_shapes || type_and_shape_required)
|
||||
loop_state_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
if (include_dim_values || type_and_shape_required) {
|
||||
mutable_shape = loop_state_output_tensor.mutable_tensor_type()->mutable_shape();
|
||||
if (!options.scalar_loop_state_value)
|
||||
mutable_shape->add_dim()->set_dim_value(1);
|
||||
}
|
||||
|
||||
TypeProto* type_proto = include_types || type_and_shape_required ? &loop_state_output_tensor : nullptr;
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("loop_state_out_1", type_proto);
|
||||
outputs.push_back(&output_arg);
|
||||
|
||||
auto* add = graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
|
||||
(void)add;
|
||||
graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
|
||||
}
|
||||
|
||||
// subgraph with multiple inputs and outputs to test variadic behaviour.
|
||||
|
|
@@ -152,7 +161,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
|
|||
// inputs must have type information and rank, but dimension can have no value if we're not providing shape info.
|
||||
concat_input_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
auto mutable_dim = concat_input_tensor.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
if (include_shapes) {
|
||||
if (include_dim_values) {
|
||||
mutable_dim->set_dim_value(2);
|
||||
|
||||
if (options.add_bad_shape) {
|
||||
|
|
@@ -168,7 +177,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
|
|||
// one output from concatenate of {4} tensor
|
||||
TypeProto concat_output_tensor;
|
||||
concat_output_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
if (include_shapes)
|
||||
if (include_dim_values)
|
||||
concat_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4);
|
||||
|
||||
TypeProto* type_proto = include_types ? &concat_output_tensor : nullptr;
|
||||
|
|
@@ -277,13 +286,18 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen
|
|||
test.AddInput<int64_t>("sequence_lens", sequence_lens_dims, *sequence_lens);
|
||||
}
|
||||
|
||||
test.AddInput<float>("scan_loop_state_in_0", {batch_size, 1}, loop_state_in_0);
|
||||
std::vector<int64_t> loop_state_shape{batch_size};
|
||||
if (!options.scalar_loop_state_value) {
|
||||
loop_state_shape.push_back(1);
|
||||
}
|
||||
|
||||
test.AddInput<float>("scan_loop_state_in_0", loop_state_shape, loop_state_in_0);
|
||||
|
||||
std::vector<int64_t> input_shape{batch_size, max_sequence_len, input_size};
|
||||
test.AddInput<float>("scan_input_0", input_shape, input_0);
|
||||
test.AddInput<float>("scan_input_1", input_shape, input_1);
|
||||
|
||||
test.AddOutput<float>("scan_loop_state_out_0", {batch_size, 1}, loop_state_out_0);
|
||||
test.AddOutput<float>("scan_loop_state_out_0", loop_state_shape, loop_state_out_0);
|
||||
|
||||
std::vector<int64_t> output_shape{batch_size, max_sequence_len, 1};
|
||||
test.AddOutput<float>("scan_output_0", output_shape, output_0);
|
||||
|
|
@@ -353,6 +367,16 @@ TEST(Scan, ShortSequenceOneInBatchOneLoopStateVar_NoShapeInMainGraph_NoTypeAndSh
|
|||
ShortSequenceOneInBatchOneLoopStateVar(options);
|
||||
}
|
||||
|
||||
TEST(Scan, OnnxScalarLoopState) {
|
||||
RunOptions options{};
|
||||
options.include_dim_values_in_main_graph = true;
|
||||
options.include_types_in_subgraph = false;
|
||||
options.include_dim_values_in_subgraph = false;
|
||||
options.scalar_loop_state_value = true;
|
||||
|
||||
ShortSequenceOneInBatchOneLoopStateVar(options);
|
||||
}
|
||||
|
||||
// test when there is an operator in the subgraph that uses a value coming from outer scope
|
||||
TEST(Scan, OuterScopeAccess_NoShapeInMainGraph_TypeAndShapeInSubgraph) {
|
||||
RunOptions options{};
|
||||
|
|
|
|||
Loading…
Reference in a new issue