Support scalars (zero dimensions) in Scan by allowing the parameters to Scan to have no dimension for the input data.

This commit is contained in:
Scott McKay 2018-11-26 10:23:26 +10:00
parent c7513e676f
commit 03d7d25989
3 changed files with 51 additions and 26 deletions

View file

@ -15,8 +15,8 @@ MLValueTensorSlicer<T> MLValueTensorSlicer<T>::Create(T& mlvalue, int64_t slice_
ONNXRUNTIME_ENFORCE(mlvalue.IsAllocated(), "MLValue has not been allocated so can't be sliced.");
auto& tensor_shape{mlvalue.template Get<Tensor>().Shape()};
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) > slice_dimension,
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) >= slice_dimension,
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
auto dim0_size = tensor_shape[0];
ONNXRUNTIME_ENFORCE(dim0_offset < dim0_size, "Invalid dim0_offset of ", dim0_offset, ". Dimension 0 is ", dim0_size);

View file

@ -303,8 +303,9 @@ static const MLValue& GetSubgraphInputMLValue(const OpKernelContextInternal& con
// Validate that the subgraph input has valid shapes
Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim,
const std::vector<const NodeArg*>& graph_inputs) {
// first dim is batch size. optional sequence dim. dim/s for the data
auto min_dims_required = has_seq_len_dim ? 3 : 2;
// first dim is batch size. optional sequence dim. dim/s for the data.
// if there is no dim for the data treat it as a scalar.
auto min_dims_required = has_seq_len_dim ? 2 : 1;
for (int i = start_input; i < end_input; ++i) {
auto& input_tensor = GetSubgraphInputTensor(context_, i);

View file

@ -17,6 +17,7 @@ struct RunOptions {
bool include_dim_values_in_subgraph = true;
bool include_types_in_subgraph = true;
bool include_outer_scope_add = false;
bool scalar_loop_state_value = false;
bool add_bad_shape = false;
};
@ -37,13 +38,13 @@ class ScanOpTester : public OpTester {
// add outer_scope_0 node. push the value through an extra Identity node as a Constant gets lifted into an
// initializer which results in different treatment by the allocation planner
{
TypeProto float_scalar;
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
TypeProto float_single_value;
float_single_value.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_single_value.mutable_tensor_type()->mutable_shape()->add_dim();
mutable_dim->set_dim_value(1);
{
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_scalar);
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_single_value);
auto* constant = graph.AddNode("outer_scope_constant", "Constant", "Constant with value kOuterNodeAddValue",
{}, {&outer_scope_constant});
@ -54,7 +55,7 @@ class ScanOpTester : public OpTester {
constant->AddAttribute("value", value_tensor);
auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_scalar);
auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_single_value);
graph.AddNode("outer_scope_id", "Identity", "Identity for outer_scope_0",
{&outer_scope_constant}, {&outer_scope_node_arg});
}
@ -66,7 +67,7 @@ class ScanOpTester : public OpTester {
};
static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string& failure_message) {
bool include_shapes = options.include_dim_values_in_subgraph;
bool include_dim_values = options.include_dim_values_in_subgraph;
bool include_types = options.include_types_in_subgraph;
std::vector<NodeArg*> inputs;
@ -94,21 +95,27 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
inputs = {};
outputs = {};
TypeProto float_scalar;
TypeProto float_input;
// inputs must have type information and a rank
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
if (include_shapes)
mutable_dim->set_dim_value(1);
float_input.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_shape = float_input.mutable_tensor_type()->mutable_shape();
if (options.scalar_loop_state_value) {
// no dims
} else {
auto mutable_dim = mutable_shape->add_dim(); // set rank
if (include_dim_values)
mutable_dim->set_dim_value(1);
}
{
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_scalar);
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_input);
outputs.push_back(&output_arg);
auto* constant = graph.AddNode("constant", "Constant", "Constant with value 1", inputs, outputs);
TensorProto value_tensor;
value_tensor.add_dims(1);
if (!options.scalar_loop_state_value)
value_tensor.add_dims(1);
value_tensor.add_float_data(1.f);
value_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
@ -118,7 +125,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
inputs = outputs; // start with output from Constant node
outputs = {};
auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_scalar);
auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_input);
inputs.push_back(&input_arg);
TypeProto loop_state_output_tensor;
@ -128,15 +135,17 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// it has to come from here.
bool type_and_shape_required = options.include_dim_values_in_main_graph == false;
if (include_shapes || type_and_shape_required)
loop_state_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
if (include_dim_values || type_and_shape_required) {
mutable_shape = loop_state_output_tensor.mutable_tensor_type()->mutable_shape();
if (!options.scalar_loop_state_value)
mutable_shape->add_dim()->set_dim_value(1);
}
TypeProto* type_proto = include_types || type_and_shape_required ? &loop_state_output_tensor : nullptr;
auto& output_arg = graph.GetOrCreateNodeArg("loop_state_out_1", type_proto);
outputs.push_back(&output_arg);
auto* add = graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
(void)add;
graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
}
// subgraph with multiple inputs and outputs to test variadic behaviour.
@ -152,7 +161,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// inputs must have type information and rank, but dimension can have no value if we're not providing shape info.
concat_input_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = concat_input_tensor.mutable_tensor_type()->mutable_shape()->add_dim();
if (include_shapes) {
if (include_dim_values) {
mutable_dim->set_dim_value(2);
if (options.add_bad_shape) {
@ -168,7 +177,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// one output from concatenate of {4} tensor
TypeProto concat_output_tensor;
concat_output_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
if (include_shapes)
if (include_dim_values)
concat_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4);
TypeProto* type_proto = include_types ? &concat_output_tensor : nullptr;
@ -277,13 +286,18 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen
test.AddInput<int64_t>("sequence_lens", sequence_lens_dims, *sequence_lens);
}
test.AddInput<float>("scan_loop_state_in_0", {batch_size, 1}, loop_state_in_0);
std::vector<int64_t> loop_state_shape{batch_size};
if (!options.scalar_loop_state_value) {
loop_state_shape.push_back(1);
}
test.AddInput<float>("scan_loop_state_in_0", loop_state_shape, loop_state_in_0);
std::vector<int64_t> input_shape{batch_size, max_sequence_len, input_size};
test.AddInput<float>("scan_input_0", input_shape, input_0);
test.AddInput<float>("scan_input_1", input_shape, input_1);
test.AddOutput<float>("scan_loop_state_out_0", {batch_size, 1}, loop_state_out_0);
test.AddOutput<float>("scan_loop_state_out_0", loop_state_shape, loop_state_out_0);
std::vector<int64_t> output_shape{batch_size, max_sequence_len, 1};
test.AddOutput<float>("scan_output_0", output_shape, output_0);
@ -353,6 +367,16 @@ TEST(Scan, ShortSequenceOneInBatchOneLoopStateVar_NoShapeInMainGraph_NoTypeAndSh
ShortSequenceOneInBatchOneLoopStateVar(options);
}
TEST(Scan, OnnxScalarLoopState) {
RunOptions options{};
options.include_dim_values_in_main_graph = true;
options.include_types_in_subgraph = false;
options.include_dim_values_in_subgraph = false;
options.scalar_loop_state_value = true;
ShortSequenceOneInBatchOneLoopStateVar(options);
}
// test when there is an operator in the subgraph that uses a value coming from outer scope
TEST(Scan, OuterScopeAccess_NoShapeInMainGraph_TypeAndShapeInSubgraph) {
RunOptions options{};