From afe3aae29fb0d24322824276bb9e76c8d838e8b6 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 2 Apr 2019 11:32:42 -0700 Subject: [PATCH] Support empty tensor concats in Concat op (#735) * Concat bug fix * CUDA concat changes --- .../core/providers/cpu/tensor/concat.cc | 70 ++++++++++++++----- .../core/providers/cpu/tensor/concat.h | 2 + .../core/providers/cuda/tensor/concat.cc | 10 ++- .../providers/cpu/tensor/concat_op_test.cc | 24 ++++++- 4 files changed, 84 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc index 498d7ddda0..91721bc0ac 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.cc +++ b/onnxruntime/core/providers/cpu/tensor/concat.cc @@ -17,59 +17,83 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep const Tensor* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& inputs_0 = *tensor_pointer; + const auto& inputs_0_dims = inputs_0.Shape().GetDims(); + const size_t inputs_0_rank = inputs_0_dims.size(); + ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars"); auto axis = HandleNegativeAxis(axis_, inputs_0.Shape().NumDimensions()); + // cache num of elements in tensor for later use + // as it's expensive to call Size() on TensorShape over and over + std::vector tensor_num_elements(input_count); // Ensure all of the non concatenated axes match each other for (int index = 1; index < input_count; index++) { + size_t num_elements = 1; tensor_pointer = ctx->Input(index); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - auto& data_n = *tensor_pointer; - // Ensure all the other axes match - auto dimension_count = inputs_0.Shape().NumDimensions(); - for (int axis_index = 0; axis_index < dimension_count; axis_index++) { + auto& inputs_n = *tensor_pointer; + const auto& inputs_n_dims = inputs_n.Shape().GetDims(); + const size_t inputs_n_rank = inputs_n_dims.size(); + ORT_ENFORCE(inputs_n_rank == inputs_0_rank, "Ranks of input data are different, cannot concatenate them, " + "expected rank: ", std::to_string(inputs_0_rank), " got: ", std::to_string(inputs_n_rank)); + // Ensure all the other (non-concat) axes match + for (int axis_index = 0; axis_index < inputs_0_rank; ++axis_index) { + num_elements *= inputs_n_dims[axis_index]; if (axis_index == axis) continue; - ORT_RETURN_IF_NOT(data_n.Shape()[axis_index] == inputs_0.Shape()[axis_index], "Non concat axis dimensions must match: Axis ", axis_index, " has mismatched dimensions of ", data_n.Shape()[axis_index], " and ", inputs_0.Shape()[axis_index]); + ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index], + "Non concat axis dimensions must match: Axis ", + axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index], + " and ", inputs_0_dims[axis_index]); } + tensor_num_elements[index] = num_elements; } // Calculate the size of the concatenated axis, and verify all other dimensions match size_t concat_axis_size = 0; for (int index = 0; index < input_count; index++) { tensor_pointer = ctx->Input(index); - if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); concat_axis_size += tensor_pointer->Shape()[int(axis)]; } // Calculate the shape of the output tensor - std::vector dims; - for (int dimension_index = 0; dimension_index < inputs_0.Shape().NumDimensions(); dimension_index++) - dims.emplace_back(inputs_0.Shape()[dimension_index]); + std::vector dims(inputs_0_rank); + size_t num_elements = 1; // cache size of the first input along the way + for (int dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) { + dims[dimension_index] = inputs_0_dims[dimension_index]; + num_elements *= inputs_0_dims[dimension_index]; + } + tensor_num_elements[0] = num_elements; dims[axis] = concat_axis_size; - TensorShape outputShape(dims); + TensorShape output_shape(dims); + + auto& concat_result = *ctx->Output(0, output_shape); + p.output_tensor = &concat_result; + p.output_num_elements = output_shape.Size(); + // if the output tensor is not going to hold any elements, + // there is no need to proceed further + if (p.output_num_elements == 0) + return Status::OK(); + // The output_axis_pitch is the number of elements to add to move to the next split axis in the output p.output_axis_pitch = 1; - for (auto i = int64_t(dims.size()); i-- > axis;) + for (auto i = int64_t(inputs_0_rank); i-- > axis;) p.output_axis_pitch *= dims[i]; - auto& concat_result = *ctx->Output(0, outputShape); - p.output_tensor = &concat_result; - for (int input_index = 0; input_index < input_count; input_index++) { const Tensor* data_n_ptr = ctx->Input(input_index); - ORT_ENFORCE(data_n_ptr != nullptr); auto& data_n = *data_n_ptr; ORT_RETURN_IF_NOT(data_n.DataType() == concat_result.DataType()); // The input_axis_pitch is the number of elements to add to move to the next split axis in the input int64_t input_axis_pitch = 1; - for (int i = int(data_n.Shape().NumDimensions()); i-- > axis;) - input_axis_pitch *= data_n.Shape()[i]; + const auto& data_dims = data_n.Shape().GetDims(); + for (int i = static_cast(inputs_0_rank); i-- > axis;) + input_axis_pitch *= data_dims[i]; - p.inputs.push_back({&data_n, input_axis_pitch}); + p.inputs.push_back({&data_n, tensor_num_elements[input_index], input_axis_pitch}); } return Status::OK(); @@ -81,15 +105,23 @@ Status Concat::Compute(OpKernelContext* ctx) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p)); + // return at this point if output tensor is going to be empty + if (p.output_num_elements == 0) + return Status::OK(); + auto is_string_type = ctx->Input(0)->DataType() == DataTypeImpl::GetType(); int64_t output_offset = 0; auto element_bytes = p.output_tensor->DataType()->Size(); for (int input_index = 0; input_index < input_count; input_index++) { const auto& prep = p.inputs[input_index]; + // no data in this tensor - so skip it + if (prep.num_elements == 0) + continue; auto input_axis_pitch = prep.axis_pitch; const uint8_t* input = static_cast(prep.tensor->DataRaw()); - auto input_size = prep.tensor->Shape().Size(); + + auto input_size = prep.num_elements; // Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch' uint8_t* output = static_cast(p.output_tensor->MutableDataRaw()); diff --git a/onnxruntime/core/providers/cpu/tensor/concat.h b/onnxruntime/core/providers/cpu/tensor/concat.h index 6493296501..def1e68cc0 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.h +++ b/onnxruntime/core/providers/cpu/tensor/concat.h @@ -21,9 +21,11 @@ class ConcatBase { struct Prepare { struct InputInfo { const Tensor* tensor; + size_t num_elements; int64_t axis_pitch; }; std::vector inputs; + size_t output_num_elements; int64_t output_axis_pitch; Tensor* output_tensor; }; diff --git a/onnxruntime/core/providers/cuda/tensor/concat.cc b/onnxruntime/core/providers/cuda/tensor/concat.cc index 27f7081de3..a4aeb51785 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat.cc +++ b/onnxruntime/core/providers/cuda/tensor/concat.cc @@ -20,11 +20,17 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p)); + // Return at this point if output tensor is going to be empty + if (p.output_num_elements == 0) + return Status::OK(); + int64_t output_offset = 0; auto element_bytes = p.output_tensor->DataType()->Size(); for (int input_index = 0; input_index < input_count; input_index++) { const auto& prep = p.inputs[input_index]; - + // No data in this tensor - so skip it + if (prep.num_elements == 0) + continue; // Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch' CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync( static_cast(p.output_tensor->MutableDataRaw()) + output_offset * element_bytes, @@ -32,7 +38,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { prep.tensor->DataRaw(), prep.axis_pitch * element_bytes, prep.axis_pitch * element_bytes, - prep.tensor->Shape().Size() / prep.axis_pitch, + prep.num_elements / prep.axis_pitch, cudaMemcpyDeviceToDevice)); output_offset += prep.axis_pitch; diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index b727db53ae..1c6d3660f1 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -40,7 +40,7 @@ TEST(MathOpTest, Concat1D_int32_negative_axis) { test.Run(); } -TEST(MathOpTest, Concat1D) { +TEST(MathOpTest, Concat1D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); @@ -51,6 +51,17 @@ TEST(MathOpTest, Concat1D) { test.Run(); } +TEST(MathOpTest, Concat1D_2) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1.0f}); + test.AddInput("input2", {2}, {2.0f, 3.0f}); + test.AddInput("input3", {0}, {}); + test.AddOutput("concat_result", {3}, {1.0f, 2.0f, 3.0f}); + test.Run(); +} + TEST(MathOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); @@ -82,6 +93,17 @@ TEST(MathOpTest, Concat2D_2) { test.Run(); } +TEST(MathOpTest, Concat2D_3) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {1, 0}, {}); + test.AddInput("input2", {1, 0}, {}); + test.AddInput("input3", {1, 0}, {}); + test.AddOutput("concat_result", {1, 0}, {}); + test.Run(); +} + TEST(MathOpTest, Concat3D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0});