mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support empty tensor concats in Concat op (#735)
* Concat bug fix * CUDA concat changes
This commit is contained in:
parent
7d47cd39b6
commit
afe3aae29f
4 changed files with 84 additions and 22 deletions
|
|
@ -17,59 +17,83 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
|
|||
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const Tensor& inputs_0 = *tensor_pointer;
|
||||
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
|
||||
const size_t inputs_0_rank = inputs_0_dims.size();
|
||||
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
|
||||
|
||||
auto axis = HandleNegativeAxis(axis_, inputs_0.Shape().NumDimensions());
|
||||
|
||||
// cache num of elements in tensor for later use
|
||||
// as it's expensive to call Size() on TensorShape over and over
|
||||
std::vector<size_t> tensor_num_elements(input_count);
|
||||
// Ensure all of the non concatenated axes match each other
|
||||
for (int index = 1; index < input_count; index++) {
|
||||
size_t num_elements = 1;
|
||||
tensor_pointer = ctx->Input<Tensor>(index);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
auto& data_n = *tensor_pointer;
|
||||
// Ensure all the other axes match
|
||||
auto dimension_count = inputs_0.Shape().NumDimensions();
|
||||
for (int axis_index = 0; axis_index < dimension_count; axis_index++) {
|
||||
auto& inputs_n = *tensor_pointer;
|
||||
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
|
||||
const size_t inputs_n_rank = inputs_n_dims.size();
|
||||
ORT_ENFORCE(inputs_n_rank == inputs_0_rank, "Ranks of input data are different, cannot concatenate them, "
|
||||
"expected rank: ", std::to_string(inputs_0_rank), " got: ", std::to_string(inputs_n_rank));
|
||||
// Ensure all the other (non-concat) axes match
|
||||
for (int axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
|
||||
num_elements *= inputs_n_dims[axis_index];
|
||||
if (axis_index == axis)
|
||||
continue;
|
||||
ORT_RETURN_IF_NOT(data_n.Shape()[axis_index] == inputs_0.Shape()[axis_index], "Non concat axis dimensions must match: Axis ", axis_index, " has mismatched dimensions of ", data_n.Shape()[axis_index], " and ", inputs_0.Shape()[axis_index]);
|
||||
ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index],
|
||||
"Non concat axis dimensions must match: Axis ",
|
||||
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
|
||||
" and ", inputs_0_dims[axis_index]);
|
||||
}
|
||||
tensor_num_elements[index] = num_elements;
|
||||
}
|
||||
|
||||
// Calculate the size of the concatenated axis, and verify all other dimensions match
|
||||
size_t concat_axis_size = 0;
|
||||
for (int index = 0; index < input_count; index++) {
|
||||
tensor_pointer = ctx->Input<Tensor>(index);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
concat_axis_size += tensor_pointer->Shape()[int(axis)];
|
||||
}
|
||||
|
||||
// Calculate the shape of the output tensor
|
||||
std::vector<int64_t> dims;
|
||||
for (int dimension_index = 0; dimension_index < inputs_0.Shape().NumDimensions(); dimension_index++)
|
||||
dims.emplace_back(inputs_0.Shape()[dimension_index]);
|
||||
std::vector<int64_t> dims(inputs_0_rank);
|
||||
size_t num_elements = 1; // cache size of the first input along the way
|
||||
for (int dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) {
|
||||
dims[dimension_index] = inputs_0_dims[dimension_index];
|
||||
num_elements *= inputs_0_dims[dimension_index];
|
||||
}
|
||||
tensor_num_elements[0] = num_elements;
|
||||
dims[axis] = concat_axis_size;
|
||||
TensorShape outputShape(dims);
|
||||
TensorShape output_shape(dims);
|
||||
|
||||
auto& concat_result = *ctx->Output(0, output_shape);
|
||||
p.output_tensor = &concat_result;
|
||||
p.output_num_elements = output_shape.Size();
|
||||
|
||||
// if the output tensor is not going to hold any elements,
|
||||
// there is no need to proceed further
|
||||
if (p.output_num_elements == 0)
|
||||
return Status::OK();
|
||||
|
||||
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output
|
||||
p.output_axis_pitch = 1;
|
||||
for (auto i = int64_t(dims.size()); i-- > axis;)
|
||||
for (auto i = int64_t(inputs_0_rank); i-- > axis;)
|
||||
p.output_axis_pitch *= dims[i];
|
||||
|
||||
auto& concat_result = *ctx->Output(0, outputShape);
|
||||
p.output_tensor = &concat_result;
|
||||
|
||||
for (int input_index = 0; input_index < input_count; input_index++) {
|
||||
const Tensor* data_n_ptr = ctx->Input<Tensor>(input_index);
|
||||
ORT_ENFORCE(data_n_ptr != nullptr);
|
||||
auto& data_n = *data_n_ptr;
|
||||
|
||||
ORT_RETURN_IF_NOT(data_n.DataType() == concat_result.DataType());
|
||||
|
||||
// The input_axis_pitch is the number of elements to add to move to the next split axis in the input
|
||||
int64_t input_axis_pitch = 1;
|
||||
for (int i = int(data_n.Shape().NumDimensions()); i-- > axis;)
|
||||
input_axis_pitch *= data_n.Shape()[i];
|
||||
const auto& data_dims = data_n.Shape().GetDims();
|
||||
for (int i = static_cast<int>(inputs_0_rank); i-- > axis;)
|
||||
input_axis_pitch *= data_dims[i];
|
||||
|
||||
p.inputs.push_back({&data_n, input_axis_pitch});
|
||||
p.inputs.push_back({&data_n, tensor_num_elements[input_index], input_axis_pitch});
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -81,15 +105,23 @@ Status Concat::Compute(OpKernelContext* ctx) const {
|
|||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
|
||||
|
||||
// return at this point if output tensor is going to be empty
|
||||
if (p.output_num_elements == 0)
|
||||
return Status::OK();
|
||||
|
||||
auto is_string_type = ctx->Input<Tensor>(0)->DataType() == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
int64_t output_offset = 0;
|
||||
auto element_bytes = p.output_tensor->DataType()->Size();
|
||||
for (int input_index = 0; input_index < input_count; input_index++) {
|
||||
const auto& prep = p.inputs[input_index];
|
||||
// no data in this tensor - so skip it
|
||||
if (prep.num_elements == 0)
|
||||
continue;
|
||||
auto input_axis_pitch = prep.axis_pitch;
|
||||
const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
|
||||
auto input_size = prep.tensor->Shape().Size();
|
||||
|
||||
auto input_size = prep.num_elements;
|
||||
|
||||
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
|
||||
uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
|
||||
|
|
|
|||
|
|
@ -21,9 +21,11 @@ class ConcatBase {
|
|||
struct Prepare {
|
||||
struct InputInfo {
|
||||
const Tensor* tensor;
|
||||
size_t num_elements;
|
||||
int64_t axis_pitch;
|
||||
};
|
||||
std::vector<InputInfo> inputs;
|
||||
size_t output_num_elements;
|
||||
int64_t output_axis_pitch;
|
||||
Tensor* output_tensor;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -20,11 +20,17 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
|
|||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
|
||||
|
||||
// Return at this point if output tensor is going to be empty
|
||||
if (p.output_num_elements == 0)
|
||||
return Status::OK();
|
||||
|
||||
int64_t output_offset = 0;
|
||||
auto element_bytes = p.output_tensor->DataType()->Size();
|
||||
for (int input_index = 0; input_index < input_count; input_index++) {
|
||||
const auto& prep = p.inputs[input_index];
|
||||
|
||||
// No data in this tensor - so skip it
|
||||
if (prep.num_elements == 0)
|
||||
continue;
|
||||
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync(
|
||||
static_cast<uint8_t*>(p.output_tensor->MutableDataRaw()) + output_offset * element_bytes,
|
||||
|
|
@ -32,7 +38,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
|
|||
prep.tensor->DataRaw(),
|
||||
prep.axis_pitch * element_bytes,
|
||||
prep.axis_pitch * element_bytes,
|
||||
prep.tensor->Shape().Size() / prep.axis_pitch,
|
||||
prep.num_elements / prep.axis_pitch,
|
||||
cudaMemcpyDeviceToDevice));
|
||||
|
||||
output_offset += prep.axis_pitch;
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ TEST(MathOpTest, Concat1D_int32_negative_axis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Concat1D) {
|
||||
TEST(MathOpTest, Concat1D_1) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{0});
|
||||
|
||||
|
|
@ -51,6 +51,17 @@ TEST(MathOpTest, Concat1D) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Concat1D_2) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{0});
|
||||
|
||||
test.AddInput<float>("input1", {1}, {1.0f});
|
||||
test.AddInput<float>("input2", {2}, {2.0f, 3.0f});
|
||||
test.AddInput<float>("input3", {0}, {});
|
||||
test.AddOutput<float>("concat_result", {3}, {1.0f, 2.0f, 3.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Concat2D_1) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{0});
|
||||
|
|
@ -82,6 +93,17 @@ TEST(MathOpTest, Concat2D_2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Concat2D_3) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{1});
|
||||
|
||||
test.AddInput<float>("input1", {1, 0}, {});
|
||||
test.AddInput<float>("input2", {1, 0}, {});
|
||||
test.AddInput<float>("input3", {1, 0}, {});
|
||||
test.AddOutput<float>("concat_result", {1, 0}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Concat3D_1) {
|
||||
OpTester test("Concat");
|
||||
test.AddAttribute("axis", int64_t{0});
|
||||
|
|
|
|||
Loading…
Reference in a new issue