diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc index 96531e71d3..7e346c20e2 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.cc +++ b/onnxruntime/core/providers/cpu/tensor/concat.cc @@ -35,66 +35,117 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const std::vector& input_tensors, Prepare& p) const { int input_count = static_cast(input_tensors.size()); + // Must have atleast one input to concat ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs"); - const Tensor* tensor_pointer = input_tensors[0]; - if (tensor_pointer == nullptr) - return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + std::vector reference_dims; + size_t reference_rank = 0; - const Tensor& inputs_0 = *tensor_pointer; - const auto& inputs_0_dims = inputs_0.Shape().GetDims(); - const size_t inputs_0_rank = inputs_0_dims.size(); + int reference_tensor_index = 0; + + std::vector input_tensor_sizes; + input_tensor_sizes.reserve(input_count); + + bool all_inputs_are_empty = true; + + for (int index = 0; index < input_count; ++index) { + const auto* input = input_tensors[index]; + ORT_ENFORCE(input != nullptr, "input count mismatch"); + + // find the first tensor that isn't empty + // to be used as a reference for all + // downstream shape/rank validations of other inputs + const auto& shape = input->Shape(); + const auto num_elements = shape.Size(); + if (num_elements > 0) { + reference_dims = shape.GetDims(); + reference_rank = reference_dims.size(); + reference_tensor_index = index; + input_tensor_sizes.push_back(num_elements); + all_inputs_are_empty = false; + break; + } else { + input_tensor_sizes.push_back(0); + } + } + + if (all_inputs_are_empty) { + // Reference dim and reference rank can just come from the first input + // No shape/rank validations will be done (as all inputs are empty). + // But the rest of the execution flow (filling in the Prepare instance - p) + // can use this info. + reference_dims = input_tensors[0]->Shape().GetDims(); + reference_rank = reference_dims.size(); + } // Cannot concatenate scalars (but they can be stacked) if (!is_stack_) - ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars"); + ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars"); // Handle and fix negative axis // In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank) - p.axis = static_cast(HandleNegativeAxis(axis_, !is_stack_ ? inputs_0_rank : inputs_0_rank + 1)); - - // Note if input tensor is empty for later use (it's expensive to call Size() on TensorShape) - std::vector input_tensor_sizes(input_count); - // Assign the number of values in the first input tensor - input_tensor_sizes[0] = inputs_0.Shape().Size(); + p.axis = static_cast(HandleNegativeAxis(axis_, !is_stack_ + ? reference_rank + : reference_rank + 1)); // Ensure all of the non concatenated axes match each other - for (int index = 1; index < input_count; index++) { - tensor_pointer = input_tensors[index]; - if (tensor_pointer == nullptr) - return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - auto& inputs_n = *tensor_pointer; - const auto& inputs_n_dims = inputs_n.Shape().GetDims(); - const size_t inputs_n_rank = inputs_n_dims.size(); - ORT_ENFORCE(inputs_n_rank == inputs_0_rank, - "Ranks of input data are different, cannot concatenate them. expected rank: ", - inputs_0_rank, " got: ", inputs_n_rank); - // Ensure all the other (non-concat) axes match - int64_t tensor_size = 1; - for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) { - auto dim_value = inputs_n_dims[axis_index]; - tensor_size *= dim_value; + for (int index = reference_tensor_index + 1; index < input_count; index++) { + const auto* input = input_tensors[index]; + ORT_ENFORCE(input != nullptr, "input count mismatch"); + const auto& input_shape = input->Shape(); + const auto& input_dims = input_shape.GetDims(); - // In 'concat' mode, the axis to be concatenated may be different - // But in 'stack' mode, all input shapes must be the same and must be validated - if (!is_stack_ && axis_index == p.axis) - continue; + // Skip shape/rank validation for inputs that are empty. + // The ONNX spec states that all dim values along axes not concatentated on + // need to be the same for all inputs (empty inputs are not explicitly exempted). + // The model in GH issue 8020 has a bunch of Loop nodes all feeding into + // the 'Concat' node and one of these Loops tend to have an iteration + // count of 0 for some inputs. If the iteration count for a Loop is zero, + // we don't execute its subgraph (since the outputs are going to be empty anyway) + // and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape + // to "compose" the shape for these empty tensor(s). + // If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0' + // in that position and due to the "lossy" nature of this process, the inputs' shape + // validation for such empty inputs fail and hence we skip these validations for all + // empty inputs. + // This isn't too bad as we will never use empty inputs while concatenating anyway. + // We just loosen this check to unblock model in GH issue 8020 to complete processing. + if (input_shape.Size() == 0) { + input_tensor_sizes.push_back(0); + } else { + const size_t input_rank = input_dims.size(); - ORT_RETURN_IF_NOT(dim_value == inputs_0_dims[axis_index], - "Non concat axis dimensions must match: Axis ", - axis_index, " has mismatched dimensions of ", dim_value, - " and ", inputs_0_dims[axis_index]); + ORT_ENFORCE(input_rank == reference_rank, + "Ranks of input data are different, cannot concatenate them. expected rank: ", + reference_rank, " got: ", input_rank); + + // Ensure all the other (non-concat) axes match + int64_t tensor_size = 1; + for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) { + auto dim_value = input_dims[axis_index]; + tensor_size *= dim_value; + + // In 'concat' mode, the axis to be concatenated may be different + // But in 'stack' mode, all input shapes must be the same and must be validated + if (!is_stack_ && axis_index == p.axis) + continue; + + ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index], + "Non concat axis dimensions must match: Axis ", + axis_index, " has mismatched dimensions of ", dim_value, + " and ", reference_dims[axis_index]); + } + + input_tensor_sizes.push_back(tensor_size); //assign the computed size of the input tensor } - - input_tensor_sizes[index] = tensor_size; //assign the computed size of the input tensor } // Calculate the shape of the output tensor - std::vector output_dims = inputs_0_dims; - // 'Concat' mode - if (!is_stack_) { - // While concating, the rank of the output is the same as the input rank(s) + std::vector output_dims = reference_dims; + + if (!is_stack_) { // 'Concat' mode + // While concatenating, the rank of the output is the same as the input rank(s) // Calculate the size of the concatenated axis size_t concat_axis_size = 0; @@ -126,7 +177,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, // The output_axis_pitch is the number of elements to add to move to the next split axis in the output. // Can handle stacking as well. p.output_axis_pitch = 1; - auto output_rank = !is_stack_ ? inputs_0_rank : inputs_0_rank + 1; + auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1; for (size_t i = output_rank; i-- > p.axis;) { p.output_axis_pitch *= output_dims[i]; } @@ -146,7 +197,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, // in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating. int64_t input_axis_pitch = 1; const auto& data_dims = data_n.Shape().GetDims(); - for (size_t i = inputs_0_rank; i-- > p.axis;) { + for (size_t i = reference_rank; i-- > p.axis;) { input_axis_pitch *= data_dims[i]; } diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index f3592357e0..a1dd83dce6 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -148,6 +148,11 @@ Status GatherND::Compute(OpKernelContext* context) const { auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); + // Bail out early in case the output is going to be empty + if (output_tensor->Shape().Size() == 0) { + return Status::OK(); + } + Prepare p; concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); if (input_tensor->IsDataTypeString()) { diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 3cd7b7dc80..64e1f958b2 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -193,6 +193,11 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { auto output_tensor = context->Output(0, TensorShape(shape)); + // Bail out early in case the output is going to be empty + if (output_tensor->Shape().Size() == 0) { + return Status::OK(); + } + // Compute int64_t num_slices; int64_t slice_size;