tile op: make implementation type-agnostic (and support a few more types) (#688)

* Initial commit

* PR feedback

* PR feedback
This commit is contained in:
Hariharan Seshadri 2019-03-25 11:55:51 -07:00 committed by GitHub
parent 6497f0c133
commit c8f1da28c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 173 additions and 95 deletions

View file

@ -24,29 +24,82 @@ namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
Tile,
6,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Tile<float>);
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<bool>()}),
Tile);
template <>
Status Tile<float>::Compute(OpKernelContext* ctx) const {
// Tiles `input_tensor` into `output_tensor` by copying raw bytes, so the same
// routine serves every element type that occupies `element_size` bytes.
//
// Parameters:
//   input_tensor    - tensor to be tiled; read through DataRaw().
//   output_tensor   - destination; written sequentially through MutableDataRaw().
//   repeats         - per-axis repeat counts, indexed by axis.
//   input_counters  - iteration state over the input; drives both loops below.
//   output_pitches  - pitches of the output tensor, indexed by axis.
//   element_size    - size in bytes of a single element.
//
// Returns Status::OK() once the counters are exhausted.
//
// NOTE(review): the last axis is addressed as `rank - 1`, so this assumes the
// input has at least one dimension, and that the output buffer is already sized
// for the tiled result - confirm both are guaranteed by the caller.
Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tensor, const int64_t* repeats, TensorAxisCounters& input_counters, const TensorPitches& output_pitches, size_t element_size) {
  const auto& dims = input_tensor.Shape().GetDims();
  const size_t last_axis = dims.size() - 1;

  const auto* src = reinterpret_cast<const uint8_t*>(input_tensor.DataRaw());
  auto* dst = reinterpret_cast<uint8_t*>(output_tensor.MutableDataRaw());

  // Number of bytes in one innermost row of the input; identical for every row.
  const size_t row_bytes = dims[last_axis] * element_size;

  while (input_counters) {
    // Move one innermost row from the input into the output.
    memcpy(dst, src, row_bytes);
    src += row_bytes;
    dst += row_bytes;

    // Repeat the row just written (repeats[last_axis] - 1) more times.
    const uint8_t* row_start = dst - row_bytes;
    for (int64_t remaining = repeats[last_axis] - 1; remaining > 0; --remaining) {
      memcpy(dst, row_start, row_bytes);
      dst += row_bytes;
    }

    // For every axis reported by a successful Increment(), duplicate the bytes
    // most recently written for that axis. The span is measured backwards from
    // the current write position.
    // NOTE(review): presumably Increment() returns true each time an outer axis
    // has been fully traversed - confirm against TensorAxisCounters.
    while (input_counters.Increment()) {
      const auto axis = input_counters.Axis();
      const ptrdiff_t span_elements = output_pitches[axis] * dims[axis];
      const size_t span_bytes = span_elements * element_size;
      const uint8_t* span_start = dst - span_bytes;

      for (int64_t remaining = repeats[axis] - 1; remaining > 0; --remaining) {
        memcpy(dst, span_start, span_bytes);
        dst += span_bytes;
      }
    }
  }

  return Status::OK();
}
Status Tile::Compute(OpKernelContext* ctx) const {
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
const Tensor& input_tensor = *tensor_pointer;
const auto& input_shape = input_tensor.Shape();
const size_t input_rank = input_shape.NumDimensions();
tensor_pointer = ctx->Input<Tensor>(1);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty");
const Tensor& repeats_tensor = *tensor_pointer;
size_t dimension_count = input_tensor.Shape().NumDimensions();
if (input_rank < 1)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "the tensor to be tiled using Tile OP must be atleast 1 dimensional");
if (repeats_tensor.Shape().NumDimensions() != 1)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional");
if (size_t(repeats_tensor.Shape().Size()) != input_tensor.Shape().NumDimensions())
if (size_t(repeats_tensor.Shape().Size()) != input_rank)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");
// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
std::vector<int64_t> output_dims = input_tensor.Shape().GetDims();
for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) {
std::vector<int64_t> output_dims = input_shape.GetDims();
for (auto axis = 0; axis < input_rank; axis++) {
output_dims[axis] *= repeats[axis];
}
@ -60,32 +113,36 @@ Status Tile<float>::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
auto* output = output_tensor.template MutableData<float>();
auto* input = input_tensor.template Data<float>();
TensorPitches output_pitches(output_tensor);
const auto& dtype = input_tensor.DataType();
TensorAxisCounters input_counters(input_tensor);
TensorPitches output_pitches(output_tensor);
while (input_counters) {
// Copy the input data over
size_t input_pitch = input_tensor.Shape().GetDims().back();
for (size_t i = 0; i < input_pitch; i++)
*output++ = *input++;
static_assert(sizeof(float) == sizeof(int32_t), "Float and Int32 are of different sizes");
static_assert(sizeof(double) == sizeof(int64_t), "Double and Int64 are of different sizes");
// Tile it for the innermost axis
const auto* copy = output - input_tensor.Shape()[dimension_count - 1];
for (int64_t repeat = (repeats[dimension_count - 1] - 1) * input_pitch; repeat-- > 0;)
*output++ = *copy++;
if (dtype == DataTypeImpl::GetType<float>() ||
dtype == DataTypeImpl::GetType<int32_t>() ||
dtype == DataTypeImpl::GetType<uint32_t>())
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(float));
// Tile it in the other axes
while (input_counters.Increment()) {
ptrdiff_t pitch = output_pitches[input_counters.Axis()] * input_tensor.Shape()[input_counters.Axis()];
copy = output - pitch;
for (int64_t repeat = (repeats[input_counters.Axis()] - 1) * pitch; repeat-- > 0;) {
*output++ = *copy++;
}
}
}
return Status::OK();
else if (dtype == DataTypeImpl::GetType<double>() ||
dtype == DataTypeImpl::GetType<int64_t>() ||
dtype == DataTypeImpl::GetType<uint64_t>())
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(double));
else if (dtype == DataTypeImpl::GetType<int8_t>() ||
dtype == DataTypeImpl::GetType<uint8_t>())
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int8_t));
else if (dtype == DataTypeImpl::GetType<int16_t>() ||
dtype == DataTypeImpl::GetType<uint16_t>())
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int16_t));
else if (dtype == DataTypeImpl::GetType<bool>())
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(bool));
// TODO: Support 'string' and 'float16' types for completeness
else
ORT_THROW("Tile doesn't have an implementation yet for the type: ", dtype);
}
} // namespace onnxruntime

View file

@ -7,7 +7,6 @@
namespace onnxruntime {
template <typename T>
struct Tile final : OpKernel {
Tile(const OpKernelInfo& info) : OpKernel(info) {
}

View file

@ -7,82 +7,104 @@
namespace onnxruntime {
namespace test {
TEST(TensorOpTest, Tile1DWithZeroRepeats) {
template <typename T>
void RunTest(std::initializer_list<T> input,
std::initializer_list<int64_t> input_dims,
std::initializer_list<int64_t> repeat,
std::initializer_list<int64_t> repeat_dims,
std::initializer_list<T> output,
std::initializer_list<int64_t> output_dims) {
OpTester test("Tile");
test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
test.AddInput<int64_t>("repeats", {1}, {0});
test.AddOutput<float>("output", {0}, {});
test.AddInput<T>("input", input_dims, input);
test.AddInput<int64_t>("repeats", repeat_dims, repeat);
test.AddOutput<T>("output", output_dims, output);
test.Run();
}
TEST(TensorOpTest, Tile2DWithZeroRepeats) {
OpTester test("Tile");
template <typename T>
void RunTestWrapper() {
// Tile1DWithZeroRepeats
RunTest<T>({1, 2, 3}, {3}, {0}, {1}, {}, {0});
test.AddInput<float>("input", {2, 2},
{11.0f, 12.0f,
21.0f, 22.0f});
test.AddInput<int64_t>("repeats", {2}, {2, 0});
test.AddOutput<float>("output", {4, 0}, {});
test.Run();
// Tile2DWithZeroRepeats
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 0}, {2}, {}, {4, 0});
// Tile1D
RunTest<T>({1, 2, 3}, {3}, {3}, {1}, {1, 2, 3, 1, 2, 3, 1, 2, 3}, {9});
// Tile2D_1Axis
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 1}, {2}, {11, 12, 21, 22, 11, 12, 21, 22}, {4, 2});
// Tile2D_2Axes
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 2}, {2}, {11, 12, 11, 12, 21, 22, 21, 22, 11, 12, 11, 12, 21, 22, 21, 22}, {4, 4});
// Tile3D
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});
}
TEST(TensorOpTest, Tile1D) {
OpTester test("Tile");
template <>
void RunTestWrapper<bool>() {
// Tile1DWithZeroRepeats
RunTest<bool>({true, false, true}, {3}, {0}, {1}, {}, {0});
test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
test.AddInput<int64_t>("repeats", {1}, {3});
test.AddOutput<float>("output", {9}, {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f});
test.Run();
// Tile2DWithZeroRepeats
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 0}, {2}, {}, {4, 0});
// Tile1D
RunTest<bool>({true, false, true}, {3}, {3}, {1}, {true, false, true, true, false, true, true, false, true}, {9});
// Tile2D_1Axis
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 1}, {2}, {true, false, true, false, true, false, true, false}, {4, 2});
// Tile2D_2Axes
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 2}, {2}, {true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}, {4, 4});
// Tile3D
RunTest<bool>({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3});
}
TEST(TensorOpTest, Tile2D_1Axis) {
OpTester test("Tile");
test.AddInput<float>("input", {2, 2},
{11.0f, 12.0f,
21.0f, 22.0f});
test.AddInput<int64_t>("repeats", {2}, {2, 1});
test.AddOutput<float>("output", {4, 2},
{11.0f, 12.0f,
21.0f, 22.0f,
11.0f, 12.0f,
21.0f, 22.0f});
test.Run();
// Runs the shared Tile scenarios from RunTestWrapper with float tensors.
TEST(TensorOpTest, TileFloatType) {
RunTestWrapper<float>();
}
TEST(TensorOpTest, Tile2D_2Axes) {
OpTester test("Tile");
test.AddInput<float>("input", {2, 2},
{11.0f, 12.0f,
21.0f, 22.0f});
test.AddInput<int64_t>("repeats", {2}, {2, 2});
test.AddOutput<float>("output", {4, 4},
{11.0f, 12.0f, 11.0f, 12.0f,
21.0f, 22.0f, 21.0f, 22.0f,
11.0f, 12.0f, 11.0f, 12.0f,
21.0f, 22.0f, 21.0f, 22.0f});
test.Run();
// Runs the shared Tile scenarios from RunTestWrapper with double tensors.
TEST(TensorOpTest, TileDoubleType) {
RunTestWrapper<double>();
}
TEST(TensorOpTest, Tile3D) {
OpTester test("Tile");
test.AddInput<float>("input", {2, 1, 3},
{111.0f, 112.0f, 113.0f,
211.0f, 212.0f, 213.0f});
test.AddInput<int64_t>("repeats", {3}, {1, 2, 1});
test.AddOutput<float>("output", {2, 2, 3},
{111.0f, 112.0f, 113.0f,
111.0f, 112.0f, 113.0f,
211.0f, 212.0f, 213.0f,
211.0f, 212.0f, 213.0f});
test.Run();
// One test case per remaining element type. Each case simply instantiates
// RunTestWrapper for that type, so every type is exercised with the same set
// of Tile scenarios.

// 8-bit signed integers.
TEST(TensorOpTest, TileInt8Type) {
RunTestWrapper<int8_t>();
}
// 8-bit unsigned integers.
TEST(TensorOpTest, TileUint8Type) {
RunTestWrapper<uint8_t>();
}
// 16-bit signed integers.
TEST(TensorOpTest, TileInt16Type) {
RunTestWrapper<int16_t>();
}
// 16-bit unsigned integers.
TEST(TensorOpTest, TileUint16Type) {
RunTestWrapper<uint16_t>();
}
// 32-bit signed integers.
TEST(TensorOpTest, TileInt32Type) {
RunTestWrapper<int32_t>();
}
// 32-bit unsigned integers.
TEST(TensorOpTest, TileUint32Type) {
RunTestWrapper<uint32_t>();
}
// 64-bit signed integers.
TEST(TensorOpTest, TileInt64Type) {
RunTestWrapper<int64_t>();
}
// 64-bit unsigned integers.
TEST(TensorOpTest, TileUint64Type) {
RunTestWrapper<uint64_t>();
}
// Booleans; uses the explicit RunTestWrapper<bool> specialization.
TEST(TensorOpTest, TileBoolType) {
RunTestWrapper<bool>();
}
} // namespace test
} // namespace onnxruntime