mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
This commit is contained in:
parent
b2247ece25
commit
08eeb8763d
3 changed files with 105 additions and 44 deletions
|
|
@ -35,66 +35,117 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
|
|||
const std::vector<const Tensor*>& input_tensors,
|
||||
Prepare& p) const {
|
||||
int input_count = static_cast<int>(input_tensors.size());
|
||||
|
||||
// Must have atleast one input to concat
|
||||
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
|
||||
|
||||
const Tensor* tensor_pointer = input_tensors[0];
|
||||
if (tensor_pointer == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
std::vector<int64_t> reference_dims;
|
||||
size_t reference_rank = 0;
|
||||
|
||||
const Tensor& inputs_0 = *tensor_pointer;
|
||||
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
|
||||
const size_t inputs_0_rank = inputs_0_dims.size();
|
||||
int reference_tensor_index = 0;
|
||||
|
||||
std::vector<int64_t> input_tensor_sizes;
|
||||
input_tensor_sizes.reserve(input_count);
|
||||
|
||||
bool all_inputs_are_empty = true;
|
||||
|
||||
for (int index = 0; index < input_count; ++index) {
|
||||
const auto* input = input_tensors[index];
|
||||
ORT_ENFORCE(input != nullptr, "input count mismatch");
|
||||
|
||||
// find the first tensor that isn't empty
|
||||
// to be used as a reference for all
|
||||
// downstream shape/rank validations of other inputs
|
||||
const auto& shape = input->Shape();
|
||||
const auto num_elements = shape.Size();
|
||||
if (num_elements > 0) {
|
||||
reference_dims = shape.GetDims();
|
||||
reference_rank = reference_dims.size();
|
||||
reference_tensor_index = index;
|
||||
input_tensor_sizes.push_back(num_elements);
|
||||
all_inputs_are_empty = false;
|
||||
break;
|
||||
} else {
|
||||
input_tensor_sizes.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (all_inputs_are_empty) {
|
||||
// Reference dim and reference rank can just come from the first input
|
||||
// No shape/rank validations will be done (as all inputs are empty).
|
||||
// But the rest of the execution flow (filling in the Prepare instance - p)
|
||||
// can use this info.
|
||||
reference_dims = input_tensors[0]->Shape().GetDims();
|
||||
reference_rank = reference_dims.size();
|
||||
}
|
||||
|
||||
// Cannot concatenate scalars (but they can be stacked)
|
||||
if (!is_stack_)
|
||||
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
|
||||
ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars");
|
||||
|
||||
// Handle and fix negative axis
|
||||
// In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
|
||||
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_ ? inputs_0_rank : inputs_0_rank + 1));
|
||||
|
||||
// Note if input tensor is empty for later use (it's expensive to call Size() on TensorShape)
|
||||
std::vector<int64_t> input_tensor_sizes(input_count);
|
||||
// Assign the number of values in the first input tensor
|
||||
input_tensor_sizes[0] = inputs_0.Shape().Size();
|
||||
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_
|
||||
? reference_rank
|
||||
: reference_rank + 1));
|
||||
|
||||
// Ensure all of the non concatenated axes match each other
|
||||
for (int index = 1; index < input_count; index++) {
|
||||
tensor_pointer = input_tensors[index];
|
||||
if (tensor_pointer == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
auto& inputs_n = *tensor_pointer;
|
||||
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
|
||||
const size_t inputs_n_rank = inputs_n_dims.size();
|
||||
ORT_ENFORCE(inputs_n_rank == inputs_0_rank,
|
||||
"Ranks of input data are different, cannot concatenate them. expected rank: ",
|
||||
inputs_0_rank, " got: ", inputs_n_rank);
|
||||
// Ensure all the other (non-concat) axes match
|
||||
int64_t tensor_size = 1;
|
||||
for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
|
||||
auto dim_value = inputs_n_dims[axis_index];
|
||||
tensor_size *= dim_value;
|
||||
for (int index = reference_tensor_index + 1; index < input_count; index++) {
|
||||
const auto* input = input_tensors[index];
|
||||
ORT_ENFORCE(input != nullptr, "input count mismatch");
|
||||
const auto& input_shape = input->Shape();
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
|
||||
// In 'concat' mode, the axis to be concatenated may be different
|
||||
// But in 'stack' mode, all input shapes must be the same and must be validated
|
||||
if (!is_stack_ && axis_index == p.axis)
|
||||
continue;
|
||||
// Skip shape/rank validation for inputs that are empty.
|
||||
// The ONNX spec states that all dim values along axes not concatentated on
|
||||
// need to be the same for all inputs (empty inputs are not explicitly exempted).
|
||||
// The model in GH issue 8020 has a bunch of Loop nodes all feeding into
|
||||
// the 'Concat' node and one of these Loops tend to have an iteration
|
||||
// count of 0 for some inputs. If the iteration count for a Loop is zero,
|
||||
// we don't execute its subgraph (since the outputs are going to be empty anyway)
|
||||
// and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape
|
||||
// to "compose" the shape for these empty tensor(s).
|
||||
// If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0'
|
||||
// in that position and due to the "lossy" nature of this process, the inputs' shape
|
||||
// validation for such empty inputs fail and hence we skip these validations for all
|
||||
// empty inputs.
|
||||
// This isn't too bad as we will never use empty inputs while concatenating anyway.
|
||||
// We just loosen this check to unblock model in GH issue 8020 to complete processing.
|
||||
if (input_shape.Size() == 0) {
|
||||
input_tensor_sizes.push_back(0);
|
||||
} else {
|
||||
const size_t input_rank = input_dims.size();
|
||||
|
||||
ORT_RETURN_IF_NOT(dim_value == inputs_0_dims[axis_index],
|
||||
"Non concat axis dimensions must match: Axis ",
|
||||
axis_index, " has mismatched dimensions of ", dim_value,
|
||||
" and ", inputs_0_dims[axis_index]);
|
||||
ORT_ENFORCE(input_rank == reference_rank,
|
||||
"Ranks of input data are different, cannot concatenate them. expected rank: ",
|
||||
reference_rank, " got: ", input_rank);
|
||||
|
||||
// Ensure all the other (non-concat) axes match
|
||||
int64_t tensor_size = 1;
|
||||
for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) {
|
||||
auto dim_value = input_dims[axis_index];
|
||||
tensor_size *= dim_value;
|
||||
|
||||
// In 'concat' mode, the axis to be concatenated may be different
|
||||
// But in 'stack' mode, all input shapes must be the same and must be validated
|
||||
if (!is_stack_ && axis_index == p.axis)
|
||||
continue;
|
||||
|
||||
ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index],
|
||||
"Non concat axis dimensions must match: Axis ",
|
||||
axis_index, " has mismatched dimensions of ", dim_value,
|
||||
" and ", reference_dims[axis_index]);
|
||||
}
|
||||
|
||||
input_tensor_sizes.push_back(tensor_size); //assign the computed size of the input tensor
|
||||
}
|
||||
|
||||
input_tensor_sizes[index] = tensor_size; //assign the computed size of the input tensor
|
||||
}
|
||||
|
||||
// Calculate the shape of the output tensor
|
||||
std::vector<int64_t> output_dims = inputs_0_dims;
|
||||
// 'Concat' mode
|
||||
if (!is_stack_) {
|
||||
// While concating, the rank of the output is the same as the input rank(s)
|
||||
std::vector<int64_t> output_dims = reference_dims;
|
||||
|
||||
if (!is_stack_) { // 'Concat' mode
|
||||
// While concatenating, the rank of the output is the same as the input rank(s)
|
||||
|
||||
// Calculate the size of the concatenated axis
|
||||
size_t concat_axis_size = 0;
|
||||
|
|
@ -126,7 +177,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
|
|||
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
|
||||
// Can handle stacking as well.
|
||||
p.output_axis_pitch = 1;
|
||||
auto output_rank = !is_stack_ ? inputs_0_rank : inputs_0_rank + 1;
|
||||
auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1;
|
||||
for (size_t i = output_rank; i-- > p.axis;) {
|
||||
p.output_axis_pitch *= output_dims[i];
|
||||
}
|
||||
|
|
@ -146,7 +197,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
|
|||
// in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating.
|
||||
int64_t input_axis_pitch = 1;
|
||||
const auto& data_dims = data_n.Shape().GetDims();
|
||||
for (size_t i = inputs_0_rank; i-- > p.axis;) {
|
||||
for (size_t i = reference_rank; i-- > p.axis;) {
|
||||
input_axis_pitch *= data_dims[i];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -148,6 +148,11 @@ Status GatherND::Compute(OpKernelContext* context) const {
|
|||
|
||||
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
|
||||
|
||||
// Bail out early in case the output is going to be empty
|
||||
if (output_tensor->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Prepare p;
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
if (input_tensor->IsDataTypeString()) {
|
||||
|
|
|
|||
|
|
@ -193,6 +193,11 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
auto output_tensor = context->Output(0, TensorShape(shape));
|
||||
|
||||
// Bail out early in case the output is going to be empty
|
||||
if (output_tensor->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Compute
|
||||
int64_t num_slices;
|
||||
int64_t slice_size;
|
||||
|
|
|
|||
Loading…
Reference in a new issue