mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
tile op: make implementation type-agnostic (and support a few more types) (#688)
* Initial commit * PR feedback * PR feedback
This commit is contained in:
parent
6497f0c133
commit
c8f1da28c4
3 changed files with 173 additions and 95 deletions
|
|
@ -24,29 +24,82 @@ namespace onnxruntime {
|
|||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Tile,
|
||||
6,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Tile<float>);
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<bool>()}),
|
||||
Tile);
|
||||
|
||||
template <>
|
||||
Status Tile<float>::Compute(OpKernelContext* ctx) const {
|
||||
// Tiles a tensor whose elements all occupy `element_size` bytes, using raw
// byte copies so that one implementation serves every fixed-size element type.
//
// input_tensor    tensor to tile; read through DataRaw().
// output_tensor   destination; written through MutableDataRaw(). This function
//                 performs no bounds checks, so the caller must have sized it
//                 to hold the tiled result.
// repeats         per-axis repeat counts, indexed by input axis.
// input_counters  axis counters over the input; advanced until exhausted.
// output_pitches  pitches of the output tensor, indexed by axis.
// element_size    size in bytes of a single element.
//
// Always returns Status::OK().
Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tensor, const int64_t* repeats, TensorAxisCounters& input_counters, const TensorPitches& output_pitches, size_t element_size) {
  const auto& dims = input_tensor.Shape().GetDims();
  const size_t last_axis = dims.size() - 1;

  const uint8_t* src = reinterpret_cast<const uint8_t*>(input_tensor.DataRaw());
  uint8_t* dst = reinterpret_cast<uint8_t*>(output_tensor.MutableDataRaw());

  // Appends (count - 1) further copies of the `bytes` bytes that end at the
  // current write position, advancing the write position past each copy.
  // The source block stays fixed, so every copy duplicates the same data.
  auto append_copies = [&dst](size_t bytes, int64_t count) {
    const uint8_t* const block = dst - bytes;
    for (int64_t extra = count - 1; extra > 0; --extra) {
      memcpy(dst, block, bytes);
      dst += bytes;
    }
  };

  // Byte length of one innermost row of the input; constant for the whole run.
  const size_t row_bytes = dims[last_axis] * element_size;

  while (input_counters) {
    // Move one innermost row from the input to the output.
    memcpy(dst, src, row_bytes);
    dst += row_bytes;
    src += row_bytes;

    // Repeat that row along the innermost axis.
    append_copies(row_bytes, repeats[last_axis]);

    // Repeat the block just completed along each outer axis reported by the
    // counters. NOTE(review): presumably Increment() returns true once per
    // axis that has just been fully traversed - confirm against
    // TensorAxisCounters.
    while (input_counters.Increment()) {
      const auto axis = input_counters.Axis();
      const ptrdiff_t pitch = output_pitches[axis] * dims[axis];
      append_copies(pitch * element_size, repeats[axis]);
    }
  }

  return Status::OK();
}
|
||||
|
||||
Status Tile::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
|
||||
const Tensor& input_tensor = *tensor_pointer;
|
||||
const auto& input_shape = input_tensor.Shape();
|
||||
const size_t input_rank = input_shape.NumDimensions();
|
||||
tensor_pointer = ctx->Input<Tensor>(1);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty");
|
||||
const Tensor& repeats_tensor = *tensor_pointer;
|
||||
|
||||
size_t dimension_count = input_tensor.Shape().NumDimensions();
|
||||
|
||||
if (input_rank < 1)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "the tensor to be tiled using Tile OP must be atleast 1 dimensional");
|
||||
if (repeats_tensor.Shape().NumDimensions() != 1)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional");
|
||||
if (size_t(repeats_tensor.Shape().Size()) != input_tensor.Shape().NumDimensions())
|
||||
if (size_t(repeats_tensor.Shape().Size()) != input_rank)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");
|
||||
|
||||
// Calculate the shape of the output tensor
|
||||
auto* repeats = repeats_tensor.template Data<int64_t>();
|
||||
std::vector<int64_t> output_dims = input_tensor.Shape().GetDims();
|
||||
for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) {
|
||||
std::vector<int64_t> output_dims = input_shape.GetDims();
|
||||
for (auto axis = 0; axis < input_rank; axis++) {
|
||||
output_dims[axis] *= repeats[axis];
|
||||
}
|
||||
|
||||
|
|
@ -60,32 +113,36 @@ Status Tile<float>::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
auto* output = output_tensor.template MutableData<float>();
|
||||
auto* input = input_tensor.template Data<float>();
|
||||
|
||||
TensorPitches output_pitches(output_tensor);
|
||||
const auto& dtype = input_tensor.DataType();
|
||||
TensorAxisCounters input_counters(input_tensor);
|
||||
TensorPitches output_pitches(output_tensor);
|
||||
|
||||
while (input_counters) {
|
||||
// Copy the input data over
|
||||
size_t input_pitch = input_tensor.Shape().GetDims().back();
|
||||
for (size_t i = 0; i < input_pitch; i++)
|
||||
*output++ = *input++;
|
||||
static_assert(sizeof(float) == sizeof(int32_t), "Float and Int32 are of different sizes");
|
||||
static_assert(sizeof(double) == sizeof(int64_t), "Double and Int64 are of different sizes");
|
||||
|
||||
// Tile it for the innermost axis
|
||||
const auto* copy = output - input_tensor.Shape()[dimension_count - 1];
|
||||
for (int64_t repeat = (repeats[dimension_count - 1] - 1) * input_pitch; repeat-- > 0;)
|
||||
*output++ = *copy++;
|
||||
if (dtype == DataTypeImpl::GetType<float>() ||
|
||||
dtype == DataTypeImpl::GetType<int32_t>() ||
|
||||
dtype == DataTypeImpl::GetType<uint32_t>())
|
||||
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(float));
|
||||
|
||||
// Tile it in the other axes
|
||||
while (input_counters.Increment()) {
|
||||
ptrdiff_t pitch = output_pitches[input_counters.Axis()] * input_tensor.Shape()[input_counters.Axis()];
|
||||
copy = output - pitch;
|
||||
for (int64_t repeat = (repeats[input_counters.Axis()] - 1) * pitch; repeat-- > 0;) {
|
||||
*output++ = *copy++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
else if (dtype == DataTypeImpl::GetType<double>() ||
|
||||
dtype == DataTypeImpl::GetType<int64_t>() ||
|
||||
dtype == DataTypeImpl::GetType<uint64_t>())
|
||||
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(double));
|
||||
|
||||
else if (dtype == DataTypeImpl::GetType<int8_t>() ||
|
||||
dtype == DataTypeImpl::GetType<uint8_t>())
|
||||
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int8_t));
|
||||
|
||||
else if (dtype == DataTypeImpl::GetType<int16_t>() ||
|
||||
dtype == DataTypeImpl::GetType<uint16_t>())
|
||||
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int16_t));
|
||||
|
||||
else if (dtype == DataTypeImpl::GetType<bool>())
|
||||
return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(bool));
|
||||
|
||||
// TODO: Support 'string' and 'float16' types for completeness
|
||||
else
|
||||
ORT_THROW("Tile doesn't have an implementation yet for the type: ", dtype);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
struct Tile final : OpKernel {
|
||||
Tile(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,82 +7,104 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(TensorOpTest, Tile1DWithZeroRepeats) {
|
||||
template <typename T>
|
||||
void RunTest(std::initializer_list<T> input,
|
||||
std::initializer_list<int64_t> input_dims,
|
||||
std::initializer_list<int64_t> repeat,
|
||||
std::initializer_list<int64_t> repeat_dims,
|
||||
std::initializer_list<T> output,
|
||||
std::initializer_list<int64_t> output_dims) {
|
||||
OpTester test("Tile");
|
||||
|
||||
test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
|
||||
test.AddInput<int64_t>("repeats", {1}, {0});
|
||||
test.AddOutput<float>("output", {0}, {});
|
||||
test.AddInput<T>("input", input_dims, input);
|
||||
test.AddInput<int64_t>("repeats", repeat_dims, repeat);
|
||||
test.AddOutput<T>("output", output_dims, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, Tile2DWithZeroRepeats) {
|
||||
OpTester test("Tile");
|
||||
template <typename T>
|
||||
void RunTestWrapper() {
|
||||
// Tile1DWithZeroRepeats
|
||||
RunTest<T>({1, 2, 3}, {3}, {0}, {1}, {}, {0});
|
||||
|
||||
test.AddInput<float>("input", {2, 2},
|
||||
{11.0f, 12.0f,
|
||||
21.0f, 22.0f});
|
||||
test.AddInput<int64_t>("repeats", {2}, {2, 0});
|
||||
test.AddOutput<float>("output", {4, 0}, {});
|
||||
test.Run();
|
||||
// Tile2DWithZeroRepeats
|
||||
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 0}, {2}, {}, {4, 0});
|
||||
|
||||
// Tile1D
|
||||
RunTest<T>({1, 2, 3}, {3}, {3}, {1}, {1, 2, 3, 1, 2, 3, 1, 2, 3}, {9});
|
||||
|
||||
// Tile2D_1Axis
|
||||
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 1}, {2}, {11, 12, 21, 22, 11, 12, 21, 22}, {4, 2});
|
||||
|
||||
// Tile2D_2Axes
|
||||
RunTest<T>({11, 12, 21, 22}, {2, 2}, {2, 2}, {2}, {11, 12, 11, 12, 21, 22, 21, 22, 11, 12, 11, 12, 21, 22, 21, 22}, {4, 4});
|
||||
|
||||
// Tile3D
|
||||
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, Tile1D) {
|
||||
OpTester test("Tile");
|
||||
template <>
|
||||
void RunTestWrapper<bool>() {
|
||||
// Tile1DWithZeroRepeats
|
||||
RunTest<bool>({true, false, true}, {3}, {0}, {1}, {}, {0});
|
||||
|
||||
test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
|
||||
test.AddInput<int64_t>("repeats", {1}, {3});
|
||||
test.AddOutput<float>("output", {9}, {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f});
|
||||
test.Run();
|
||||
// Tile2DWithZeroRepeats
|
||||
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 0}, {2}, {}, {4, 0});
|
||||
|
||||
// Tile1D
|
||||
RunTest<bool>({true, false, true}, {3}, {3}, {1}, {true, false, true, true, false, true, true, false, true}, {9});
|
||||
|
||||
// Tile2D_1Axis
|
||||
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 1}, {2}, {true, false, true, false, true, false, true, false}, {4, 2});
|
||||
|
||||
// Tile2D_2Axes
|
||||
RunTest<bool>({true, false, true, false}, {2, 2}, {2, 2}, {2}, {true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}, {4, 4});
|
||||
|
||||
// Tile3D
|
||||
RunTest<bool>({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3});
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, Tile2D_1Axis) {
|
||||
OpTester test("Tile");
|
||||
|
||||
test.AddInput<float>("input", {2, 2},
|
||||
{11.0f, 12.0f,
|
||||
21.0f, 22.0f});
|
||||
test.AddInput<int64_t>("repeats", {2}, {2, 1});
|
||||
test.AddOutput<float>("output", {4, 2},
|
||||
{11.0f, 12.0f,
|
||||
21.0f, 22.0f,
|
||||
11.0f, 12.0f,
|
||||
21.0f, 22.0f});
|
||||
|
||||
test.Run();
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with float elements.
TEST(TensorOpTest, TileFloatType) {
|
||||
RunTestWrapper<float>();
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, Tile2D_2Axes) {
|
||||
OpTester test("Tile");
|
||||
|
||||
test.AddInput<float>("input", {2, 2},
|
||||
{11.0f, 12.0f,
|
||||
21.0f, 22.0f});
|
||||
test.AddInput<int64_t>("repeats", {2}, {2, 2});
|
||||
test.AddOutput<float>("output", {4, 4},
|
||||
{11.0f, 12.0f, 11.0f, 12.0f,
|
||||
21.0f, 22.0f, 21.0f, 22.0f,
|
||||
11.0f, 12.0f, 11.0f, 12.0f,
|
||||
21.0f, 22.0f, 21.0f, 22.0f});
|
||||
|
||||
test.Run();
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with double elements.
TEST(TensorOpTest, TileDoubleType) {
|
||||
RunTestWrapper<double>();
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, Tile3D) {
|
||||
OpTester test("Tile");
|
||||
|
||||
test.AddInput<float>("input", {2, 1, 3},
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
211.0f, 212.0f, 213.0f});
|
||||
test.AddInput<int64_t>("repeats", {3}, {1, 2, 1});
|
||||
test.AddOutput<float>("output", {2, 2, 3},
|
||||
{111.0f, 112.0f, 113.0f,
|
||||
111.0f, 112.0f, 113.0f,
|
||||
|
||||
211.0f, 212.0f, 213.0f,
|
||||
211.0f, 212.0f, 213.0f});
|
||||
test.Run();
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with int8_t elements.
TEST(TensorOpTest, TileInt8Type) {
|
||||
RunTestWrapper<int8_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with uint8_t elements.
TEST(TensorOpTest, TileUint8Type) {
|
||||
RunTestWrapper<uint8_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with int16_t elements.
TEST(TensorOpTest, TileInt16Type) {
|
||||
RunTestWrapper<int16_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with uint16_t elements.
TEST(TensorOpTest, TileUint16Type) {
|
||||
RunTestWrapper<uint16_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with int32_t elements.
TEST(TensorOpTest, TileInt32Type) {
|
||||
RunTestWrapper<int32_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with uint32_t elements.
TEST(TensorOpTest, TileUint32Type) {
|
||||
RunTestWrapper<uint32_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with int64_t elements.
TEST(TensorOpTest, TileInt64Type) {
|
||||
RunTestWrapper<int64_t>();
|
||||
}
|
||||
|
||||
// Runs the Tile scenarios defined in RunTestWrapper with uint64_t elements.
TEST(TensorOpTest, TileUint64Type) {
|
||||
RunTestWrapper<uint64_t>();
|
||||
}
|
||||
|
||||
// Runs the bool specialization of RunTestWrapper, which supplies true/false
// literals in place of the numeric fixtures used for the other element types.
TEST(TensorOpTest, TileBoolType) {
|
||||
RunTestWrapper<bool>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue