mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Yanchen/nuphar/scatter elems (#1992)
* Added Scatter and ScatterElements to Nuphar Implemented Scatter (op_ver 9 - 10) and ScatterElements (op_ver 11) nuphar. Because TVM's compute is output-oriented, our current implementation uses extern calls for simplicity. * fixed build issue after rebase * remove dead code * Address CR * removed dead code * use GetAttrOrDefault * Address more CR feedback * add GetStrides to codegen/common/utils.h * added a unit test for Bool input data
This commit is contained in:
parent
c86d17754a
commit
15138908e7
9 changed files with 327 additions and 0 deletions
|
|
@ -52,4 +52,16 @@ int64_t TotalSize(const std::vector<int64_t>& shape) {
|
|||
return total;
|
||||
}
|
||||
|
||||
// Return the strides for the input shape, i.e. the number of
|
||||
// elements contained by a single element of current dimension.
|
||||
// For example, for shape[3][4][5][6], strides will be
|
||||
// [4*5*6, 5*6, 6, 1], i.e. [120, 30, 6, 1]
|
||||
void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides) {
|
||||
strides.resize(ndim);
|
||||
strides[ndim - 1] = 1;
|
||||
for (int64_t i = ndim - 2; i >= 0; i--) {
|
||||
strides[i] = strides[i+1] * shape[i+1];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -17,4 +17,6 @@ bool IsEnvVarDefined(const char* var);
|
|||
|
||||
int64_t TotalSize(const std::vector<int64_t>& shape);
|
||||
|
||||
void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -52,6 +52,8 @@ namespace nuphar {
|
|||
ADD_OP_ITEM(MatMul) \
|
||||
ADD_OP_ITEM(MatMulInteger) \
|
||||
ADD_OP_ITEM(MatMulInteger16) \
|
||||
ADD_OP_ITEM(Scatter) \
|
||||
ADD_OP_ITEM(ScatterElements) \
|
||||
ADD_OP_ITEM(Slice) \
|
||||
ADD_OP_ITEM(Softmax) \
|
||||
ADD_OP_ITEM(Tile)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h"
|
||||
|
||||
#include "core/framework/op_kernel_info.h"
|
||||
#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace nuphar {
|
||||
|
||||
static Status ScatterCommon(
|
||||
const tvm::Array<tvm::Tensor>& inputs,
|
||||
const Node& node,
|
||||
tvm_codegen::CodeGenContext&,
|
||||
tvm::Array<tvm::Tensor>& outputs,
|
||||
const std::string& name) {
|
||||
ProtoHelperNodeContext ctx(node);
|
||||
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
|
||||
|
||||
// The default value of optional attribute axis is 0
|
||||
int64_t axis = attrs.GetAttrOrDefault<int64_t>("axis", 0);
|
||||
|
||||
tvm::Tensor Y = Scatter(inputs[0], axis, inputs[1], inputs[2], name);
|
||||
outputs.push_back(Y);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Evaluate of Scatter OpIRCreator
|
||||
Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Scatter)::Evaluate(
|
||||
const tvm::Array<tvm::Tensor>& inputs,
|
||||
const Node& node,
|
||||
tvm_codegen::CodeGenContext& codegen_ctx,
|
||||
tvm::Array<tvm::Tensor>& outputs) {
|
||||
return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_ScatterElements");
|
||||
}
|
||||
|
||||
// Evaluate of ScatterElements OpIRCreator
|
||||
Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(ScatterElements)::Evaluate(
|
||||
const tvm::Array<tvm::Tensor>& inputs,
|
||||
const Node& node,
|
||||
tvm_codegen::CodeGenContext& codegen_ctx,
|
||||
tvm::Array<tvm::Tensor>& outputs) {
|
||||
return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_Scatter");
|
||||
}
|
||||
|
||||
} // namespace nuphar
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -227,4 +227,27 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
|
||||
nuphar::NupharKernel);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Scatter,
|
||||
kOnnxDomain,
|
||||
9,
|
||||
10,
|
||||
kNupharExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
nuphar::NupharKernel);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
ScatterElements,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
kNupharExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
nuphar::NupharKernel);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
202
onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc
Normal file
202
onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
|
||||
|
||||
#include "core/codegen/common/utils.h"
|
||||
#include "core/codegen/mti/mti_tvm_utils.h"
|
||||
#include "core/common/common.h"
|
||||
#include <topi/detail/extern.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace nuphar {
|
||||
|
||||
#define STRINGIFY_1(x) #x
|
||||
#define STRINGIFY(x) STRINGIFY_1(x)
|
||||
#define GET_EXTERN_SCATTER_STR(input_type, index_type) \
|
||||
STRINGIFY(tvm.contrib.onnxruntime.scatter_##input_type##index_type)
|
||||
|
||||
static int64_t DLTensorSize(const DLTensor* dl_tensor) {
|
||||
int64_t sz = 1;
|
||||
for (int i = 0; i < dl_tensor->ndim; ++i) {
|
||||
sz *= dl_tensor->shape[i];
|
||||
}
|
||||
return sz;
|
||||
}
|
||||
|
||||
template<class T, class Tind>
|
||||
void ScatterCommon(tvm::TVMArgs args, tvm::TVMRetValue* /*ret*/) {
|
||||
DLTensor* input = args[0];
|
||||
DLTensor* indices = args[1];
|
||||
DLTensor* updates = args[2];
|
||||
DLTensor* output = args[3];
|
||||
int axis = args[4];
|
||||
|
||||
int num_dims = input->ndim;
|
||||
DCHECK(axis < num_dims);
|
||||
|
||||
for (int i = 0; i < num_dims; i++) {
|
||||
if (indices->shape[i] != updates->shape[i]) {
|
||||
ORT_THROW("Indices vs updates dimensions differs at position=", i,
|
||||
" ", indices->shape[i], " vs ", updates->shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// extract indices from raw data
|
||||
Tind* indices_data = reinterpret_cast<Tind*>(static_cast<char*>(indices->data) + indices->byte_offset);
|
||||
int64_t indices_size = DLTensorSize(indices);
|
||||
std::vector<Tind> indices_data_vec(indices_size);
|
||||
int64_t axis_size = input->shape[axis];
|
||||
for (int64_t i = 0; i < indices_size; i++) {
|
||||
Tind idx = indices_data[i];
|
||||
// indices can be negative values
|
||||
if (idx >= -axis_size && idx < axis_size) {
|
||||
indices_data_vec[i] = idx >= 0 ? idx : idx + static_cast<Tind>(axis_size);
|
||||
} else {
|
||||
ORT_THROW("indices element out of data bounds, idx=", idx,
|
||||
" must be within the inclusive range [", -axis_size,
|
||||
",", axis_size - 1, "]");
|
||||
}
|
||||
}
|
||||
|
||||
// copy input data into output
|
||||
int64_t input_size = DLTensorSize(input);
|
||||
memcpy(static_cast<char*>(output->data) + output->byte_offset,
|
||||
static_cast<char*>(input->data) + input->byte_offset,
|
||||
input_size * input->dtype.bits / 8);
|
||||
|
||||
std::vector<int64_t> input_strides;
|
||||
GetStrides(input->shape, num_dims, input_strides);
|
||||
|
||||
T* output_data = reinterpret_cast<T*>(static_cast<char*>(output->data) + output->byte_offset);
|
||||
T* updates_data = reinterpret_cast<T*>(static_cast<char*>(updates->data) + updates->byte_offset);
|
||||
const std::vector<int> indices_shape(indices->shape, indices->shape + num_dims);
|
||||
// Because indices data is flat, running_indices maintains indices's original dimensions.
|
||||
// We will use its dimensions to compute the corresponding index (or offset) to output_data,
|
||||
// which is also flat.
|
||||
std::vector<int64_t> running_indices(num_dims, 0);
|
||||
|
||||
for (int64_t i = 0; i < indices_size; i++) {
|
||||
Tind idx = indices_data_vec[i];
|
||||
// output indices come from running_indices
|
||||
std::vector<int64_t> curr_output_indices = running_indices;
|
||||
curr_output_indices[axis] = static_cast<int64_t>(idx);
|
||||
|
||||
// get the index into output_data
|
||||
int64_t output_idx = 0;
|
||||
for (int j = 0; j < num_dims; j++) {
|
||||
output_idx += curr_output_indices[j] * input_strides[j];
|
||||
}
|
||||
|
||||
// update data
|
||||
output_data[output_idx] = updates_data[i];
|
||||
|
||||
// update running_indices
|
||||
Tind carry = 1;
|
||||
for (int j = num_dims - 1; j >= 0; j--) {
|
||||
if (carry == 0) break;
|
||||
Tind curr_idx = running_indices[j] + carry;
|
||||
running_indices[j] = curr_idx % indices_shape[j];
|
||||
carry = curr_idx / indices_shape[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define REGISTER_EXTERN_SCATTER(input_type, index_type) \
|
||||
TVM_REGISTER_GLOBAL(GET_EXTERN_SCATTER_STR(input_type, index_type)) \
|
||||
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { \
|
||||
ScatterCommon<input_type, index_type>(args, ret); \
|
||||
});
|
||||
|
||||
#define REGISTER_EXTERN_SCATTER_PAIR(input_type) \
|
||||
REGISTER_EXTERN_SCATTER(input_type, int32_t) \
|
||||
REGISTER_EXTERN_SCATTER(input_type, int64_t) \
|
||||
|
||||
REGISTER_EXTERN_SCATTER_PAIR(bool)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(int8_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(uint8_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(int16_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(uint16_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(int32_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(uint32_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(int64_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(uint64_t)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(float)
|
||||
REGISTER_EXTERN_SCATTER_PAIR(double)
|
||||
|
||||
#undef REGISTER_EXTERN_SCATTER
|
||||
|
||||
static tvm::Tensor MakeExternScatter(const tvm::Tensor& t,
|
||||
int64_t axis_p,
|
||||
const tvm::Tensor& indices,
|
||||
const tvm::Tensor& updates,
|
||||
const std::string& name,
|
||||
const char* extern_scatter) {
|
||||
// handle negative axis
|
||||
int64_t input_rank = static_cast<int64_t>(t->shape.size());
|
||||
DCHECK(input_rank >= 1);
|
||||
DCHECK(input_rank == static_cast<int64_t>(indices->shape.size()));
|
||||
DCHECK(input_rank == static_cast<int64_t>(updates->shape.size()));
|
||||
int axis = static_cast<int>(tvm_codegen::HandleNegativeAxis(axis_p, input_rank));
|
||||
|
||||
// output has the same shape as input
|
||||
tvm::Array<tvm::Expr> output_shape;
|
||||
for (int64_t i = 0; i < input_rank; i++) {
|
||||
output_shape.push_back(t->shape[i]);
|
||||
}
|
||||
|
||||
return topi::detail::make_extern(
|
||||
/*output_shapes*/ {output_shape},
|
||||
/*output_types*/ {t->dtype},
|
||||
/*inputs*/ {t, indices, updates},
|
||||
[&](tvm::Array<tvm::Buffer> ins, tvm::Array<tvm::Buffer> outs) {
|
||||
tvm::Array<tvm::Expr> args = {tvm::Expr(extern_scatter),
|
||||
topi::detail::pack_buffer(ins[0]),
|
||||
topi::detail::pack_buffer(ins[1]),
|
||||
topi::detail::pack_buffer(ins[2]),
|
||||
topi::detail::pack_buffer(outs[0]),
|
||||
axis};
|
||||
return topi::detail::call_packed(args);
|
||||
},
|
||||
name, /*tag*/ "", /*attrs*/ {})[0];
|
||||
}
|
||||
|
||||
tvm::Tensor Scatter(const tvm::Tensor& t,
|
||||
int64_t axis_p,
|
||||
const tvm::Tensor& indices,
|
||||
const tvm::Tensor& updates,
|
||||
const std::string& name) {
|
||||
|
||||
#define MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, index_tensor_type, input_type, index_type) \
|
||||
if (t->dtype == input_tensor_type && indices->dtype == index_tensor_type) { \
|
||||
return MakeExternScatter(t, axis_p, indices, updates, name, GET_EXTERN_SCATTER_STR(input_type, index_type)); \
|
||||
}
|
||||
|
||||
#define MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(input_tensor_type, input_type) \
|
||||
MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(32), input_type, int32_t) \
|
||||
MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(64), input_type, int64_t)
|
||||
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Bool(), bool)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(8), int8_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(8), uint8_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(16), int16_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(16), uint16_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(32), int32_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(32), uint32_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(64), int64_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(64), uint64_t)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(32), float)
|
||||
MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(64), double)
|
||||
|
||||
#undef MAKE_EXTERN_SCATTER_PAIR_IF_MATCH
|
||||
#undef MAKE_EXTERN_SCATTER_IF_MATCH
|
||||
|
||||
ORT_NOT_IMPLEMENTED("input type is not implementated");
|
||||
}
|
||||
|
||||
#undef STRINGIFY_1
|
||||
#undef STRINGIFY
|
||||
#undef GET_EXTERN_SCATTER_STR
|
||||
|
||||
} // namespace tvm_codegen
|
||||
} // namespace onnxruntime
|
||||
18
onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h
Normal file
18
onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <tvm/tvm.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace nuphar {
|
||||
|
||||
tvm::Tensor Scatter(const tvm::Tensor& t,
|
||||
int64_t axis,
|
||||
const tvm::Tensor& indices,
|
||||
const tvm::Tensor& updates,
|
||||
const std::string& name = "scatter");
|
||||
|
||||
} // namespace nuphar
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -394,6 +394,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements);
|
||||
|
||||
static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
|
||||
#define NUPHAR_OP(name, ver, types) \
|
||||
|
|
@ -414,6 +416,8 @@ static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements)>());
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRegistry> NupharExecutionProvider::GetKernelRegistryInternal() const {
|
||||
|
|
|
|||
|
|
@ -206,5 +206,21 @@ TEST(Scatter, ValidNegativeIndex) {
|
|||
scatter_valid_negative_index("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_bool_with_axis_tests(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
||||
test.AddInput<bool>("data", {1, 5}, {false, false, false, true, false});
|
||||
test.AddInput<int64_t>("indices", {1, 2}, {1, 3});
|
||||
test.AddInput<bool>("updates", {1, 2}, {true, false});
|
||||
test.AddOutput<bool>("y", {1, 5}, {false, true, false, false, false});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Scatter, BoolInputWithAxis) {
|
||||
scatter_bool_with_axis_tests("Scatter", 9);
|
||||
scatter_bool_with_axis_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue