Support empty tensor concats in Concat op (#735)

* Concat bug fix

* CUDA concat changes
This commit is contained in:
Hariharan Seshadri 2019-04-02 11:32:42 -07:00 committed by GitHub
parent 7d47cd39b6
commit afe3aae29f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 22 deletions

View file

@ -17,59 +17,83 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor& inputs_0 = *tensor_pointer;
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
const size_t inputs_0_rank = inputs_0_dims.size();
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
auto axis = HandleNegativeAxis(axis_, inputs_0.Shape().NumDimensions());
// cache num of elements in tensor for later use
// as it's expensive to call Size() on TensorShape over and over
std::vector<size_t> tensor_num_elements(input_count);
// Ensure all of the non concatenated axes match each other
for (int index = 1; index < input_count; index++) {
size_t num_elements = 1;
tensor_pointer = ctx->Input<Tensor>(index);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
auto& data_n = *tensor_pointer;
// Ensure all the other axes match
auto dimension_count = inputs_0.Shape().NumDimensions();
for (int axis_index = 0; axis_index < dimension_count; axis_index++) {
auto& inputs_n = *tensor_pointer;
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
const size_t inputs_n_rank = inputs_n_dims.size();
ORT_ENFORCE(inputs_n_rank == inputs_0_rank, "Ranks of input data are different, cannot concatenate them, "
"expected rank: ", std::to_string(inputs_0_rank), " got: ", std::to_string(inputs_n_rank));
// Ensure all the other (non-concat) axes match
for (int axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
num_elements *= inputs_n_dims[axis_index];
if (axis_index == axis)
continue;
ORT_RETURN_IF_NOT(data_n.Shape()[axis_index] == inputs_0.Shape()[axis_index], "Non concat axis dimensions must match: Axis ", axis_index, " has mismatched dimensions of ", data_n.Shape()[axis_index], " and ", inputs_0.Shape()[axis_index]);
ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
" and ", inputs_0_dims[axis_index]);
}
tensor_num_elements[index] = num_elements;
}
// Calculate the size of the concatenated axis, and verify all other dimensions match
size_t concat_axis_size = 0;
for (int index = 0; index < input_count; index++) {
tensor_pointer = ctx->Input<Tensor>(index);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
concat_axis_size += tensor_pointer->Shape()[int(axis)];
}
// Calculate the shape of the output tensor
std::vector<int64_t> dims;
for (int dimension_index = 0; dimension_index < inputs_0.Shape().NumDimensions(); dimension_index++)
dims.emplace_back(inputs_0.Shape()[dimension_index]);
std::vector<int64_t> dims(inputs_0_rank);
size_t num_elements = 1; // cache size of the first input along the way
for (int dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) {
dims[dimension_index] = inputs_0_dims[dimension_index];
num_elements *= inputs_0_dims[dimension_index];
}
tensor_num_elements[0] = num_elements;
dims[axis] = concat_axis_size;
TensorShape outputShape(dims);
TensorShape output_shape(dims);
auto& concat_result = *ctx->Output(0, output_shape);
p.output_tensor = &concat_result;
p.output_num_elements = output_shape.Size();
// if the output tensor is not going to hold any elements,
// there is no need to proceed further
if (p.output_num_elements == 0)
return Status::OK();
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output
p.output_axis_pitch = 1;
for (auto i = int64_t(dims.size()); i-- > axis;)
for (auto i = int64_t(inputs_0_rank); i-- > axis;)
p.output_axis_pitch *= dims[i];
auto& concat_result = *ctx->Output(0, outputShape);
p.output_tensor = &concat_result;
for (int input_index = 0; input_index < input_count; input_index++) {
const Tensor* data_n_ptr = ctx->Input<Tensor>(input_index);
ORT_ENFORCE(data_n_ptr != nullptr);
auto& data_n = *data_n_ptr;
ORT_RETURN_IF_NOT(data_n.DataType() == concat_result.DataType());
// The input_axis_pitch is the number of elements to add to move to the next split axis in the input
int64_t input_axis_pitch = 1;
for (int i = int(data_n.Shape().NumDimensions()); i-- > axis;)
input_axis_pitch *= data_n.Shape()[i];
const auto& data_dims = data_n.Shape().GetDims();
for (int i = static_cast<int>(inputs_0_rank); i-- > axis;)
input_axis_pitch *= data_dims[i];
p.inputs.push_back({&data_n, input_axis_pitch});
p.inputs.push_back({&data_n, tensor_num_elements[input_index], input_axis_pitch});
}
return Status::OK();
@ -81,15 +105,23 @@ Status Concat::Compute(OpKernelContext* ctx) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
// return at this point if output tensor is going to be empty
if (p.output_num_elements == 0)
return Status::OK();
auto is_string_type = ctx->Input<Tensor>(0)->DataType() == DataTypeImpl::GetType<std::string>();
int64_t output_offset = 0;
auto element_bytes = p.output_tensor->DataType()->Size();
for (int input_index = 0; input_index < input_count; input_index++) {
const auto& prep = p.inputs[input_index];
// no data in this tensor - so skip it
if (prep.num_elements == 0)
continue;
auto input_axis_pitch = prep.axis_pitch;
const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
auto input_size = prep.tensor->Shape().Size();
auto input_size = prep.num_elements;
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());

View file

@ -21,9 +21,11 @@ class ConcatBase {
// Scratch data filled in by PrepareForCompute() and consumed by the
// Compute()/ComputeInternal() implementations of the Concat kernels.
struct Prepare {
// Per-input bookkeeping; one entry is pushed for every input tensor.
struct InputInfo {
// The input tensor itself (not owned).
const Tensor* tensor;
// Total element count of this input (product of all its dimensions), cached
// so callers need not call Size() on the shape again. 0 means the input is
// empty and callers skip it.
size_t num_elements;
// Product of this input's dimensions from the concat axis through the last
// dimension, i.e. the number of contiguous elements copied per step.
int64_t axis_pitch;
};
// NOTE: `inputs` and `output_axis_pitch` are only populated when
// output_num_elements != 0; PrepareForCompute() returns early otherwise.
std::vector<InputInfo> inputs;
// Total element count of the output shape; 0 means there is nothing to copy.
size_t output_num_elements;
// Product of the output dimensions from the concat axis through the last
// dimension: the number of elements to advance to reach the next slice
// along the concat axis in the output.
int64_t output_axis_pitch;
// Output tensor obtained from the kernel context (not owned).
Tensor* output_tensor;
};

View file

@ -20,11 +20,17 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
// Return at this point if output tensor is going to be empty
if (p.output_num_elements == 0)
return Status::OK();
int64_t output_offset = 0;
auto element_bytes = p.output_tensor->DataType()->Size();
for (int input_index = 0; input_index < input_count; input_index++) {
const auto& prep = p.inputs[input_index];
// No data in this tensor - so skip it
if (prep.num_elements == 0)
continue;
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync(
static_cast<uint8_t*>(p.output_tensor->MutableDataRaw()) + output_offset * element_bytes,
@ -32,7 +38,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
prep.tensor->DataRaw(),
prep.axis_pitch * element_bytes,
prep.axis_pitch * element_bytes,
prep.tensor->Shape().Size() / prep.axis_pitch,
prep.num_elements / prep.axis_pitch,
cudaMemcpyDeviceToDevice));
output_offset += prep.axis_pitch;

View file

@ -40,7 +40,7 @@ TEST(MathOpTest, Concat1D_int32_negative_axis) {
test.Run();
}
TEST(MathOpTest, Concat1D) {
TEST(MathOpTest, Concat1D_1) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});
@ -51,6 +51,17 @@ TEST(MathOpTest, Concat1D) {
test.Run();
}
// Concatenates three 1-D float inputs along axis 0 where the last input is
// empty (shape {0}). The empty input must contribute nothing, so the expected
// output is just the elements of the first two inputs, shape {3}.
TEST(MathOpTest, Concat1D_2) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});
test.AddInput<float>("input1", {1}, {1.0f});
test.AddInput<float>("input2", {2}, {2.0f, 3.0f});
// Zero-length input: shape {0}, no values.
test.AddInput<float>("input3", {0}, {});
test.AddOutput<float>("concat_result", {3}, {1.0f, 2.0f, 3.0f});
test.Run();
}
TEST(MathOpTest, Concat2D_1) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});
@ -82,6 +93,17 @@ TEST(MathOpTest, Concat2D_2) {
test.Run();
}
// Concatenates three 2-D float inputs along axis 1 where every input has
// shape {1, 0} (no elements). The concat-axis sizes sum to 0, so the expected
// output is itself an empty tensor of shape {1, 0}.
TEST(MathOpTest, Concat2D_3) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});
test.AddInput<float>("input1", {1, 0}, {});
test.AddInput<float>("input2", {1, 0}, {});
test.AddInput<float>("input3", {1, 0}, {});
test.AddOutput<float>("concat_result", {1, 0}, {});
test.Run();
}
TEST(MathOpTest, Concat3D_1) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});