Yanchen/nuphar/scatter elems (#1992)

* Added Scatter and ScatterElements to Nuphar

Implemented Scatter (op_ver 9 - 10) and ScatterElements (op_ver 11)
nuphar.

Because TVM's compute is output-oriented, our current implementation
uses extern calls for simplicity.

* fixed build issue after rebase

* remove dead code

* Address CR

* removed dead code

* use GetAttrOrDefault

* Address more CR feedback

* add GetStrides to codegen/common/utils.h

* added a unit test for Bool input data
This commit is contained in:
Yang Chen 2019-10-03 14:58:10 -07:00 committed by GitHub
parent c86d17754a
commit 15138908e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 327 additions and 0 deletions

View file

@ -52,4 +52,16 @@ int64_t TotalSize(const std::vector<int64_t>& shape) {
return total;
}
// Return the strides for the input shape, i.e. the number of
// elements contained by a single element of current dimension.
// For example, for shape[3][4][5][6], strides will be
// [4*5*6, 5*6, 6, 1], i.e. [120, 30, 6, 1]
void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides) {
strides.resize(ndim);
strides[ndim - 1] = 1;
for (int64_t i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i+1] * shape[i+1];
}
}
} // namespace onnxruntime

View file

@ -17,4 +17,6 @@ bool IsEnvVarDefined(const char* var);
int64_t TotalSize(const std::vector<int64_t>& shape);
void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides);
} // namespace onnxruntime

View file

@ -52,6 +52,8 @@ namespace nuphar {
ADD_OP_ITEM(MatMul) \
ADD_OP_ITEM(MatMulInteger) \
ADD_OP_ITEM(MatMulInteger16) \
ADD_OP_ITEM(Scatter) \
ADD_OP_ITEM(ScatterElements) \
ADD_OP_ITEM(Slice) \
ADD_OP_ITEM(Softmax) \
ADD_OP_ITEM(Tile)

View file

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h"
#include "core/framework/op_kernel_info.h"
#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
namespace onnxruntime {
namespace nuphar {
static Status ScatterCommon(
const tvm::Array<tvm::Tensor>& inputs,
const Node& node,
tvm_codegen::CodeGenContext&,
tvm::Array<tvm::Tensor>& outputs,
const std::string& name) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
// The default value of optional attribute axis is 0
int64_t axis = attrs.GetAttrOrDefault<int64_t>("axis", 0);
tvm::Tensor Y = Scatter(inputs[0], axis, inputs[1], inputs[2], name);
outputs.push_back(Y);
return Status::OK();
}
// Evaluate of Scatter OpIRCreator
Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Scatter)::Evaluate(
const tvm::Array<tvm::Tensor>& inputs,
const Node& node,
tvm_codegen::CodeGenContext& codegen_ctx,
tvm::Array<tvm::Tensor>& outputs) {
return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_ScatterElements");
}
// Evaluate of ScatterElements OpIRCreator
Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(ScatterElements)::Evaluate(
const tvm::Array<tvm::Tensor>& inputs,
const Node& node,
tvm_codegen::CodeGenContext& codegen_ctx,
tvm::Array<tvm::Tensor>& outputs) {
return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_Scatter");
}
} // namespace nuphar
} // namespace onnxruntime

View file

@ -227,4 +227,27 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
nuphar::NupharKernel);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Scatter,
kOnnxDomain,
9,
10,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
ScatterElements,
kOnnxDomain,
11,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);
} // namespace onnxruntime

View file

@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
#include "core/codegen/common/utils.h"
#include "core/codegen/mti/mti_tvm_utils.h"
#include "core/common/common.h"
#include <topi/detail/extern.h>
namespace onnxruntime {
namespace nuphar {
#define STRINGIFY_1(x) #x
#define STRINGIFY(x) STRINGIFY_1(x)
#define GET_EXTERN_SCATTER_STR(input_type, index_type) \
STRINGIFY(tvm.contrib.onnxruntime.scatter_##input_type##index_type)
static int64_t DLTensorSize(const DLTensor* dl_tensor) {
int64_t sz = 1;
for (int i = 0; i < dl_tensor->ndim; ++i) {
sz *= dl_tensor->shape[i];
}
return sz;
}
template<class T, class Tind>
void ScatterCommon(tvm::TVMArgs args, tvm::TVMRetValue* /*ret*/) {
DLTensor* input = args[0];
DLTensor* indices = args[1];
DLTensor* updates = args[2];
DLTensor* output = args[3];
int axis = args[4];
int num_dims = input->ndim;
DCHECK(axis < num_dims);
for (int i = 0; i < num_dims; i++) {
if (indices->shape[i] != updates->shape[i]) {
ORT_THROW("Indices vs updates dimensions differs at position=", i,
" ", indices->shape[i], " vs ", updates->shape[i]);
}
}
// extract indices from raw data
Tind* indices_data = reinterpret_cast<Tind*>(static_cast<char*>(indices->data) + indices->byte_offset);
int64_t indices_size = DLTensorSize(indices);
std::vector<Tind> indices_data_vec(indices_size);
int64_t axis_size = input->shape[axis];
for (int64_t i = 0; i < indices_size; i++) {
Tind idx = indices_data[i];
// indices can be negative values
if (idx >= -axis_size && idx < axis_size) {
indices_data_vec[i] = idx >= 0 ? idx : idx + static_cast<Tind>(axis_size);
} else {
ORT_THROW("indices element out of data bounds, idx=", idx,
" must be within the inclusive range [", -axis_size,
",", axis_size - 1, "]");
}
}
// copy input data into output
int64_t input_size = DLTensorSize(input);
memcpy(static_cast<char*>(output->data) + output->byte_offset,
static_cast<char*>(input->data) + input->byte_offset,
input_size * input->dtype.bits / 8);
std::vector<int64_t> input_strides;
GetStrides(input->shape, num_dims, input_strides);
T* output_data = reinterpret_cast<T*>(static_cast<char*>(output->data) + output->byte_offset);
T* updates_data = reinterpret_cast<T*>(static_cast<char*>(updates->data) + updates->byte_offset);
const std::vector<int> indices_shape(indices->shape, indices->shape + num_dims);
// Because indices data is flat, running_indices maintains indices's original dimensions.
// We will use its dimensions to compute the corresponding index (or offset) to output_data,
// which is also flat.
std::vector<int64_t> running_indices(num_dims, 0);
for (int64_t i = 0; i < indices_size; i++) {
Tind idx = indices_data_vec[i];
// output indices come from running_indices
std::vector<int64_t> curr_output_indices = running_indices;
curr_output_indices[axis] = static_cast<int64_t>(idx);
// get the index into output_data
int64_t output_idx = 0;
for (int j = 0; j < num_dims; j++) {
output_idx += curr_output_indices[j] * input_strides[j];
}
// update data
output_data[output_idx] = updates_data[i];
// update running_indices
Tind carry = 1;
for (int j = num_dims - 1; j >= 0; j--) {
if (carry == 0) break;
Tind curr_idx = running_indices[j] + carry;
running_indices[j] = curr_idx % indices_shape[j];
carry = curr_idx / indices_shape[j];
}
}
}
#define REGISTER_EXTERN_SCATTER(input_type, index_type) \
TVM_REGISTER_GLOBAL(GET_EXTERN_SCATTER_STR(input_type, index_type)) \
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { \
ScatterCommon<input_type, index_type>(args, ret); \
});
#define REGISTER_EXTERN_SCATTER_PAIR(input_type) \
REGISTER_EXTERN_SCATTER(input_type, int32_t) \
REGISTER_EXTERN_SCATTER(input_type, int64_t) \
REGISTER_EXTERN_SCATTER_PAIR(bool)
REGISTER_EXTERN_SCATTER_PAIR(int8_t)
REGISTER_EXTERN_SCATTER_PAIR(uint8_t)
REGISTER_EXTERN_SCATTER_PAIR(int16_t)
REGISTER_EXTERN_SCATTER_PAIR(uint16_t)
REGISTER_EXTERN_SCATTER_PAIR(int32_t)
REGISTER_EXTERN_SCATTER_PAIR(uint32_t)
REGISTER_EXTERN_SCATTER_PAIR(int64_t)
REGISTER_EXTERN_SCATTER_PAIR(uint64_t)
REGISTER_EXTERN_SCATTER_PAIR(float)
REGISTER_EXTERN_SCATTER_PAIR(double)
#undef REGISTER_EXTERN_SCATTER
static tvm::Tensor MakeExternScatter(const tvm::Tensor& t,
int64_t axis_p,
const tvm::Tensor& indices,
const tvm::Tensor& updates,
const std::string& name,
const char* extern_scatter) {
// handle negative axis
int64_t input_rank = static_cast<int64_t>(t->shape.size());
DCHECK(input_rank >= 1);
DCHECK(input_rank == static_cast<int64_t>(indices->shape.size()));
DCHECK(input_rank == static_cast<int64_t>(updates->shape.size()));
int axis = static_cast<int>(tvm_codegen::HandleNegativeAxis(axis_p, input_rank));
// output has the same shape as input
tvm::Array<tvm::Expr> output_shape;
for (int64_t i = 0; i < input_rank; i++) {
output_shape.push_back(t->shape[i]);
}
return topi::detail::make_extern(
/*output_shapes*/ {output_shape},
/*output_types*/ {t->dtype},
/*inputs*/ {t, indices, updates},
[&](tvm::Array<tvm::Buffer> ins, tvm::Array<tvm::Buffer> outs) {
tvm::Array<tvm::Expr> args = {tvm::Expr(extern_scatter),
topi::detail::pack_buffer(ins[0]),
topi::detail::pack_buffer(ins[1]),
topi::detail::pack_buffer(ins[2]),
topi::detail::pack_buffer(outs[0]),
axis};
return topi::detail::call_packed(args);
},
name, /*tag*/ "", /*attrs*/ {})[0];
}
tvm::Tensor Scatter(const tvm::Tensor& t,
int64_t axis_p,
const tvm::Tensor& indices,
const tvm::Tensor& updates,
const std::string& name) {
#define MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, index_tensor_type, input_type, index_type) \
if (t->dtype == input_tensor_type && indices->dtype == index_tensor_type) { \
return MakeExternScatter(t, axis_p, indices, updates, name, GET_EXTERN_SCATTER_STR(input_type, index_type)); \
}
#define MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(input_tensor_type, input_type) \
MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(32), input_type, int32_t) \
MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(64), input_type, int64_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Bool(), bool)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(8), int8_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(8), uint8_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(16), int16_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(16), uint16_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(32), int32_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(32), uint32_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(64), int64_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(64), uint64_t)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(32), float)
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(64), double)
#undef MAKE_EXTERN_SCATTER_PAIR_IF_MATCH
#undef MAKE_EXTERN_SCATTER_IF_MATCH
ORT_NOT_IMPLEMENTED("input type is not implementated");
}
#undef STRINGIFY_1
#undef STRINGIFY
#undef GET_EXTERN_SCATTER_STR
} // namespace tvm_codegen
} // namespace onnxruntime

View file

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <tvm/tvm.h>
namespace onnxruntime {
namespace nuphar {
tvm::Tensor Scatter(const tvm::Tensor& t,
int64_t axis,
const tvm::Tensor& indices,
const tvm::Tensor& updates,
const std::string& name = "scatter");
} // namespace nuphar
} // namespace onnxruntime

View file

@ -394,6 +394,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements);
static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
#define NUPHAR_OP(name, ver, types) \
@ -414,6 +416,8 @@ static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements)>());
}
std::shared_ptr<KernelRegistry> NupharExecutionProvider::GetKernelRegistryInternal() const {

View file

@ -206,5 +206,21 @@ TEST(Scatter, ValidNegativeIndex) {
scatter_valid_negative_index("ScatterElements", 11);
}
static void scatter_bool_with_axis_tests(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<bool>("data", {1, 5}, {false, false, false, true, false});
test.AddInput<int64_t>("indices", {1, 2}, {1, 3});
test.AddInput<bool>("updates", {1, 2}, {true, false});
test.AddOutput<bool>("y", {1, 5}, {false, true, false, false, false});
test.Run();
}
TEST(Scatter, BoolInputWithAxis) {
scatter_bool_with_axis_tests("Scatter", 9);
scatter_bool_with_axis_tests("ScatterElements", 11);
}
} // namespace test
} // namespace onnxruntime