Loosen validation checks in Concat to unblock execution of model in #8020 (#8080)

This commit is contained in:
Hariharan Seshadri 2021-06-18 11:14:36 -07:00 committed by GitHub
parent b2247ece25
commit 08eeb8763d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 44 deletions

View file

@ -35,66 +35,117 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
const std::vector<const Tensor*>& input_tensors,
Prepare& p) const {
int input_count = static_cast<int>(input_tensors.size());
// Must have atleast one input to concat
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
const Tensor* tensor_pointer = input_tensors[0];
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
std::vector<int64_t> reference_dims;
size_t reference_rank = 0;
const Tensor& inputs_0 = *tensor_pointer;
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
const size_t inputs_0_rank = inputs_0_dims.size();
int reference_tensor_index = 0;
std::vector<int64_t> input_tensor_sizes;
input_tensor_sizes.reserve(input_count);
bool all_inputs_are_empty = true;
for (int index = 0; index < input_count; ++index) {
const auto* input = input_tensors[index];
ORT_ENFORCE(input != nullptr, "input count mismatch");
// find the first tensor that isn't empty
// to be used as a reference for all
// downstream shape/rank validations of other inputs
const auto& shape = input->Shape();
const auto num_elements = shape.Size();
if (num_elements > 0) {
reference_dims = shape.GetDims();
reference_rank = reference_dims.size();
reference_tensor_index = index;
input_tensor_sizes.push_back(num_elements);
all_inputs_are_empty = false;
break;
} else {
input_tensor_sizes.push_back(0);
}
}
if (all_inputs_are_empty) {
// Reference dim and reference rank can just come from the first input
// No shape/rank validations will be done (as all inputs are empty).
// But the rest of the execution flow (filling in the Prepare instance - p)
// can use this info.
reference_dims = input_tensors[0]->Shape().GetDims();
reference_rank = reference_dims.size();
}
// Cannot concatenate scalars (but they can be stacked)
if (!is_stack_)
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars");
// Handle and fix negative axis
// In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_ ? inputs_0_rank : inputs_0_rank + 1));
// Note if input tensor is empty for later use (it's expensive to call Size() on TensorShape)
std::vector<int64_t> input_tensor_sizes(input_count);
// Assign the number of values in the first input tensor
input_tensor_sizes[0] = inputs_0.Shape().Size();
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_
? reference_rank
: reference_rank + 1));
// Ensure all of the non concatenated axes match each other
for (int index = 1; index < input_count; index++) {
tensor_pointer = input_tensors[index];
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
auto& inputs_n = *tensor_pointer;
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
const size_t inputs_n_rank = inputs_n_dims.size();
ORT_ENFORCE(inputs_n_rank == inputs_0_rank,
"Ranks of input data are different, cannot concatenate them. expected rank: ",
inputs_0_rank, " got: ", inputs_n_rank);
// Ensure all the other (non-concat) axes match
int64_t tensor_size = 1;
for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
auto dim_value = inputs_n_dims[axis_index];
tensor_size *= dim_value;
for (int index = reference_tensor_index + 1; index < input_count; index++) {
const auto* input = input_tensors[index];
ORT_ENFORCE(input != nullptr, "input count mismatch");
const auto& input_shape = input->Shape();
const auto& input_dims = input_shape.GetDims();
// In 'concat' mode, the axis to be concatenated may be different
// But in 'stack' mode, all input shapes must be the same and must be validated
if (!is_stack_ && axis_index == p.axis)
continue;
// Skip shape/rank validation for inputs that are empty.
// The ONNX spec states that all dim values along axes not concatentated on
// need to be the same for all inputs (empty inputs are not explicitly exempted).
// The model in GH issue 8020 has a bunch of Loop nodes all feeding into
// the 'Concat' node and one of these Loops tend to have an iteration
// count of 0 for some inputs. If the iteration count for a Loop is zero,
// we don't execute its subgraph (since the outputs are going to be empty anyway)
// and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape
// to "compose" the shape for these empty tensor(s).
// If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0'
// in that position and due to the "lossy" nature of this process, the inputs' shape
// validation for such empty inputs fail and hence we skip these validations for all
// empty inputs.
// This isn't too bad as we will never use empty inputs while concatenating anyway.
// We just loosen this check to unblock model in GH issue 8020 to complete processing.
if (input_shape.Size() == 0) {
input_tensor_sizes.push_back(0);
} else {
const size_t input_rank = input_dims.size();
ORT_RETURN_IF_NOT(dim_value == inputs_0_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", dim_value,
" and ", inputs_0_dims[axis_index]);
ORT_ENFORCE(input_rank == reference_rank,
"Ranks of input data are different, cannot concatenate them. expected rank: ",
reference_rank, " got: ", input_rank);
// Ensure all the other (non-concat) axes match
int64_t tensor_size = 1;
for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) {
auto dim_value = input_dims[axis_index];
tensor_size *= dim_value;
// In 'concat' mode, the axis to be concatenated may be different
// But in 'stack' mode, all input shapes must be the same and must be validated
if (!is_stack_ && axis_index == p.axis)
continue;
ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", dim_value,
" and ", reference_dims[axis_index]);
}
input_tensor_sizes.push_back(tensor_size); //assign the computed size of the input tensor
}
input_tensor_sizes[index] = tensor_size; //assign the computed size of the input tensor
}
// Calculate the shape of the output tensor
std::vector<int64_t> output_dims = inputs_0_dims;
// 'Concat' mode
if (!is_stack_) {
// While concating, the rank of the output is the same as the input rank(s)
std::vector<int64_t> output_dims = reference_dims;
if (!is_stack_) { // 'Concat' mode
// While concatenating, the rank of the output is the same as the input rank(s)
// Calculate the size of the concatenated axis
size_t concat_axis_size = 0;
@ -126,7 +177,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
// Can handle stacking as well.
p.output_axis_pitch = 1;
auto output_rank = !is_stack_ ? inputs_0_rank : inputs_0_rank + 1;
auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1;
for (size_t i = output_rank; i-- > p.axis;) {
p.output_axis_pitch *= output_dims[i];
}
@ -146,7 +197,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
// in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating.
int64_t input_axis_pitch = 1;
const auto& data_dims = data_n.Shape().GetDims();
for (size_t i = inputs_0_rank; i-- > p.axis;) {
for (size_t i = reference_rank; i-- > p.axis;) {
input_axis_pitch *= data_dims[i];
}

View file

@ -148,6 +148,11 @@ Status GatherND::Compute(OpKernelContext* context) const {
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
// Bail out early in case the output is going to be empty
if (output_tensor->Shape().Size() == 0) {
return Status::OK();
}
Prepare p;
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
if (input_tensor->IsDataTypeString()) {

View file

@ -193,6 +193,11 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
auto output_tensor = context->Output(0, TensorShape(shape));
// Bail out early in case the output is going to be empty
if (output_tensor->Shape().Size() == 0) {
return Status::OK();
}
// Compute
int64_t num_slices;
int64_t slice_size;