From 15138908e7ff72b542f681ad46e6bddc242e8a49 Mon Sep 17 00:00:00 2001
From: Yang Chen <40417152+yangchen-MS@users.noreply.github.com>
Date: Thu, 3 Oct 2019 14:58:10 -0700
Subject: [PATCH] Yanchen/nuphar/scatter elems (#1992)

* Added Scatter and ScatterElements to Nuphar

Implemented Scatter (op_ver 9 - 10) and ScatterElements (op_ver 11) nuphar.
Because TVM's compute is output-oriented, our current implementation
uses extern calls for simplicity.

* fixed build issue after rebase

* remove dead code

* Address CR

* removed dead code

* use GetAttrOrDefault

* Address more CR feedback

* add GetStrides to codegen/common/utils.h

* added a unit test for Bool input data
---
 onnxruntime/core/codegen/common/utils.cc      |  12 ++
 onnxruntime/core/codegen/common/utils.h       |   2 +
 .../compiler/x86/op_ir_creator/all_ops.h      |   2 +
 .../x86/op_ir_creator/tensor/scatter.cc       |  48 +++++
 onnxruntime/core/providers/nuphar/kernel.cc   |  23 ++
 .../nuphar/mti_x86/tensor/scatter.cc          | 202 ++++++++++++++++++
 .../providers/nuphar/mti_x86/tensor/scatter.h |  18 ++
 .../nuphar/nuphar_execution_provider.cc       |   4 +
 .../providers/cpu/tensor/scatter_op_test.cc   |  16 ++
 9 files changed, 327 insertions(+)
 create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/scatter.cc
 create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc
 create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h

diff --git a/onnxruntime/core/codegen/common/utils.cc b/onnxruntime/core/codegen/common/utils.cc
index 1c49a3a073..dfa3abf06a 100644
--- a/onnxruntime/core/codegen/common/utils.cc
+++ b/onnxruntime/core/codegen/common/utils.cc
@@ -52,4 +52,16 @@ int64_t TotalSize(const std::vector<int64_t>& shape) {
   return total;
 }
 
+// Return the strides for the input shape, i.e. the number of
+// elements contained by a single element of current dimension.
+// For example, for shape[3][4][5][6], strides will be
+// [4*5*6, 5*6, 6, 1], i.e. [120, 30, 6, 1]
+void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides) {
+  strides.assign(ndim > 0 ? static_cast<size_t>(ndim) : 0, 1);
+  // ndim <= 0 (scalar shape) yields empty strides; the innermost stride stays 1.
+  for (int i = ndim - 2; i >= 0; i--) {
+    strides[i] = strides[i + 1] * shape[i + 1];
+  }
+}
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/codegen/common/utils.h b/onnxruntime/core/codegen/common/utils.h
index 40f300888d..d85df1d01e 100644
--- a/onnxruntime/core/codegen/common/utils.h
+++ b/onnxruntime/core/codegen/common/utils.h
@@ -17,4 +17,6 @@ bool IsEnvVarDefined(const char* var);
 
 int64_t TotalSize(const std::vector<int64_t>& shape);
 
+void GetStrides(const int64_t* shape, int ndim, std::vector<int64_t>& strides);
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h
index 094e30ffa5..6f7304908e 100644
--- a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h
@@ -52,6 +52,8 @@ namespace nuphar {
   ADD_OP_ITEM(MatMul) \
   ADD_OP_ITEM(MatMulInteger) \
   ADD_OP_ITEM(MatMulInteger16) \
+  ADD_OP_ITEM(Scatter) \
+  ADD_OP_ITEM(ScatterElements) \
   ADD_OP_ITEM(Slice) \
   ADD_OP_ITEM(Softmax) \
   ADD_OP_ITEM(Tile)
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/scatter.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/scatter.cc
new file mode 100644
index 0000000000..c33afb3bbc
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/scatter.cc
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h"
+
+#include "core/framework/op_kernel_info.h"
+#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
+
+namespace onnxruntime {
+namespace nuphar {
+
+static Status ScatterCommon(
+    const tvm::Array<tvm::Tensor>& inputs,
+    const Node& node,
+    tvm_codegen::CodeGenContext&,
+    tvm::Array<tvm::Tensor>& outputs,
+    const std::string& name) {
+  ProtoHelperNodeContext ctx(node);
+  OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
+
+  // The default value of optional attribute axis is 0
+  int64_t axis = attrs.GetAttrOrDefault<int64_t>("axis", 0);
+
+  tvm::Tensor Y = Scatter(inputs[0], axis, inputs[1], inputs[2], name);
+  outputs.push_back(Y);
+  return Status::OK();
+}
+
+// Evaluate of Scatter OpIRCreator
+Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Scatter)::Evaluate(
+    const tvm::Array<tvm::Tensor>& inputs,
+    const Node& node,
+    tvm_codegen::CodeGenContext& codegen_ctx,
+    tvm::Array<tvm::Tensor>& outputs) {
+  return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_Scatter");
+}
+
+// Evaluate of ScatterElements OpIRCreator
+Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(ScatterElements)::Evaluate(
+    const tvm::Array<tvm::Tensor>& inputs,
+    const Node& node,
+    tvm_codegen::CodeGenContext& codegen_ctx,
+    tvm::Array<tvm::Tensor>& outputs) {
+  return ScatterCommon(inputs, node, codegen_ctx, outputs, node.Name() + "_ScatterElements");
+}
+
+}  // namespace nuphar
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/kernel.cc b/onnxruntime/core/providers/nuphar/kernel.cc
index 2c2997d517..6b22a70217 100644
--- a/onnxruntime/core/providers/nuphar/kernel.cc
+++ b/onnxruntime/core/providers/nuphar/kernel.cc
@@ -227,4 +227,27 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
     nuphar::NupharKernel);
 
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Scatter,
+    kOnnxDomain,
+    9,
+    10,
+    kNupharExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+        .TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                                        DataTypeImpl::GetTensorType<int64_t>()}),
+    nuphar::NupharKernel);
+
+ONNX_OPERATOR_KERNEL_EX(
+    ScatterElements,
+    kOnnxDomain,
+    11,
+    kNupharExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+        .TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                                        DataTypeImpl::GetTensorType<int64_t>()}),
+    nuphar::NupharKernel);
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc b/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc
new file mode 100644
index 0000000000..5bf0ecca76
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.cc
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/nuphar/mti_x86/tensor/scatter.h"
+
+#include "core/codegen/common/utils.h"
+#include "core/codegen/mti/mti_tvm_utils.h"
+#include "core/common/common.h"
+#include <topi/detail/extern.h>
+
+namespace onnxruntime {
+namespace nuphar {
+
+#define STRINGIFY_1(x) #x
+#define STRINGIFY(x) STRINGIFY_1(x)
+#define GET_EXTERN_SCATTER_STR(input_type, index_type) \
+  STRINGIFY(tvm.contrib.onnxruntime.scatter_##input_type##index_type)
+
+static int64_t DLTensorSize(const DLTensor* dl_tensor) {
+  int64_t sz = 1;
+  for (int i = 0; i < dl_tensor->ndim; ++i) {
+    sz *= dl_tensor->shape[i];
+  }
+  return sz;
+}
+
+template <typename T, typename Tind>
+void ScatterCommon(tvm::TVMArgs args, tvm::TVMRetValue* /*ret*/) {
+  DLTensor* input = args[0];
+  DLTensor* indices = args[1];
+  DLTensor* updates = args[2];
+  DLTensor* output = args[3];
+  int axis = args[4];
+
+  int num_dims = input->ndim;
+  DCHECK(axis < num_dims);
+
+  for (int i = 0; i < num_dims; i++) {
+    if (indices->shape[i] != updates->shape[i]) {
+      ORT_THROW("Indices vs updates dimensions differs at position=", i,
+                " ", indices->shape[i], " vs ", updates->shape[i]);
+    }
+  }
+
+  // extract indices from raw data
+  Tind* indices_data = reinterpret_cast<Tind*>(static_cast<char*>(indices->data) + indices->byte_offset);
+  int64_t indices_size = DLTensorSize(indices);
+  std::vector<Tind> indices_data_vec(indices_size);
+  int64_t axis_size = input->shape[axis];
+  for (int64_t i = 0; i < indices_size; i++) {
+    Tind idx = indices_data[i];
+    // indices can be negative values
+    if (idx >= -axis_size && idx < axis_size) {
+      indices_data_vec[i] = idx >= 0 ? idx : idx + static_cast<Tind>(axis_size);
+    } else {
+      ORT_THROW("indices element out of data bounds, idx=", idx,
+                " must be within the inclusive range [", -axis_size,
+                ",", axis_size - 1, "]");
+    }
+  }
+
+  // copy input data into output; sized by sizeof(T) because dtype.bits may be sub-byte (e.g. bool)
+  int64_t input_size = DLTensorSize(input);
+  memcpy(static_cast<char*>(output->data) + output->byte_offset,
+         static_cast<char*>(input->data) + input->byte_offset,
+         input_size * sizeof(T));
+
+  std::vector<int64_t> input_strides;
+  GetStrides(input->shape, num_dims, input_strides);
+
+  T* output_data = reinterpret_cast<T*>(static_cast<char*>(output->data) + output->byte_offset);
+  T* updates_data = reinterpret_cast<T*>(static_cast<char*>(updates->data) + updates->byte_offset);
+  const std::vector<int64_t> indices_shape(indices->shape, indices->shape + num_dims);
+  // Because indices data is flat, running_indices maintains indices's original dimensions.
+  // We will use its dimensions to compute the corresponding index (or offset) to output_data,
+  // which is also flat.
+  std::vector<int64_t> running_indices(num_dims, 0);
+
+  for (int64_t i = 0; i < indices_size; i++) {
+    Tind idx = indices_data_vec[i];
+    // output indices come from running_indices
+    std::vector<int64_t> curr_output_indices = running_indices;
+    curr_output_indices[axis] = static_cast<int64_t>(idx);
+
+    // get the index into output_data
+    int64_t output_idx = 0;
+    for (int j = 0; j < num_dims; j++) {
+      output_idx += curr_output_indices[j] * input_strides[j];
+    }
+
+    // update data
+    output_data[output_idx] = updates_data[i];
+
+    // update running_indices (odometer-style increment, innermost dimension first)
+    int64_t carry = 1;
+    for (int j = num_dims - 1; j >= 0; j--) {
+      if (carry == 0) break;
+      int64_t curr_idx = running_indices[j] + carry;
+      running_indices[j] = curr_idx % indices_shape[j];
+      carry = curr_idx / indices_shape[j];
+    }
+  }
+}
+
+#define REGISTER_EXTERN_SCATTER(input_type, index_type) \
+  TVM_REGISTER_GLOBAL(GET_EXTERN_SCATTER_STR(input_type, index_type)) \
+      .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { \
+        ScatterCommon<input_type, index_type>(args, ret); \
+      });
+
+#define REGISTER_EXTERN_SCATTER_PAIR(input_type) \
+  REGISTER_EXTERN_SCATTER(input_type, int32_t) \
+  REGISTER_EXTERN_SCATTER(input_type, int64_t) \
+
+REGISTER_EXTERN_SCATTER_PAIR(bool)
+REGISTER_EXTERN_SCATTER_PAIR(int8_t)
+REGISTER_EXTERN_SCATTER_PAIR(uint8_t)
+REGISTER_EXTERN_SCATTER_PAIR(int16_t)
+REGISTER_EXTERN_SCATTER_PAIR(uint16_t)
+REGISTER_EXTERN_SCATTER_PAIR(int32_t)
+REGISTER_EXTERN_SCATTER_PAIR(uint32_t)
+REGISTER_EXTERN_SCATTER_PAIR(int64_t)
+REGISTER_EXTERN_SCATTER_PAIR(uint64_t)
+REGISTER_EXTERN_SCATTER_PAIR(float)
+REGISTER_EXTERN_SCATTER_PAIR(double)
+
+#undef REGISTER_EXTERN_SCATTER
+
+static tvm::Tensor MakeExternScatter(const tvm::Tensor& t,
+                                     int64_t axis_p,
+                                     const tvm::Tensor& indices,
+                                     const tvm::Tensor& updates,
+                                     const std::string& name,
+                                     const char* extern_scatter) {
+  // handle negative axis
+  int64_t input_rank = static_cast<int64_t>(t->shape.size());
+  DCHECK(input_rank >= 1);
+  DCHECK(input_rank == static_cast<int64_t>(indices->shape.size()));
+  DCHECK(input_rank == static_cast<int64_t>(updates->shape.size()));
+  int axis = static_cast<int>(tvm_codegen::HandleNegativeAxis(axis_p, input_rank));
+
+  // output has the same shape as input
+  tvm::Array<tvm::Expr> output_shape;
+  for (int64_t i = 0; i < input_rank; i++) {
+    output_shape.push_back(t->shape[i]);
+  }
+
+  return topi::detail::make_extern(
+      /*output_shapes*/ {output_shape},
+      /*output_types*/ {t->dtype},
+      /*inputs*/ {t, indices, updates},
+      [&](tvm::Array<tvm::Buffer> ins, tvm::Array<tvm::Buffer> outs) {
+        tvm::Array<tvm::Expr> args = {tvm::Expr(extern_scatter),
+                                      topi::detail::pack_buffer(ins[0]),
+                                      topi::detail::pack_buffer(ins[1]),
+                                      topi::detail::pack_buffer(ins[2]),
+                                      topi::detail::pack_buffer(outs[0]),
+                                      axis};
+        return topi::detail::call_packed(args);
+      },
+      name, /*tag*/ "", /*attrs*/ {})[0];
+}
+
+tvm::Tensor Scatter(const tvm::Tensor& t,
+                    int64_t axis_p,
+                    const tvm::Tensor& indices,
+                    const tvm::Tensor& updates,
+                    const std::string& name) {
+
+#define MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, index_tensor_type, input_type, index_type) \
+  if (t->dtype == input_tensor_type && indices->dtype == index_tensor_type) { \
+    return MakeExternScatter(t, axis_p, indices, updates, name, GET_EXTERN_SCATTER_STR(input_type, index_type)); \
+  }
+
+#define MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(input_tensor_type, input_type) \
+  MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(32), input_type, int32_t) \
+  MAKE_EXTERN_SCATTER_IF_MATCH(input_tensor_type, tvm::Int(64), input_type, int64_t)
+
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Bool(), bool)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(8), int8_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(8), uint8_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(16), int16_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(16), uint16_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(32), int32_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(32), uint32_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Int(64), int64_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::UInt(64), uint64_t)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(32), float)
+  MAKE_EXTERN_SCATTER_PAIR_IF_MATCH(tvm::Float(64), double)
+
+#undef MAKE_EXTERN_SCATTER_PAIR_IF_MATCH
+#undef MAKE_EXTERN_SCATTER_IF_MATCH
+
+  ORT_NOT_IMPLEMENTED("input type is not implementated");
+}
+
+#undef STRINGIFY_1
+#undef STRINGIFY
+#undef GET_EXTERN_SCATTER_STR
+
+}  // namespace nuphar
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h b/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h
new file mode 100644
index 0000000000..1955d9d092
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/mti_x86/tensor/scatter.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <string>
+#include <tvm/tvm.h>
+
+namespace onnxruntime {
+namespace nuphar {
+
+tvm::Tensor Scatter(const tvm::Tensor& t,
+                    int64_t axis,
+                    const tvm::Tensor& indices,
+                    const tvm::Tensor& updates,
+                    const std::string& name = "scatter");
+
+}  // namespace nuphar
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc
index 1a16e3a68f..1d059d188a 100644
--- a/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc
+++ b/onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc
@@ -394,6 +394,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1,
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements);
 
 static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
 #define NUPHAR_OP(name, ver, types) \
@@ -414,6 +416,8 @@ static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
   kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger)>());
   kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16)>());
   kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter)>());
+  kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements)>());
 }
 
 std::shared_ptr<KernelRegistry> NupharExecutionProvider::GetKernelRegistryInternal() const {
diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
index 7dd0337074..d61c2163ea 100644
--- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -206,5 +206,21 @@ TEST(Scatter, ValidNegativeIndex) {
   scatter_valid_negative_index("ScatterElements", 11);
 }
 
+static void scatter_bool_with_axis_tests(const char* op_name, int op_version) {
+  OpTester test(op_name, op_version);
+  test.AddAttribute<int64_t>("axis", 1);
+
+  test.AddInput<bool>("data", {1, 5}, {false, false, false, true, false});
+  test.AddInput<int64_t>("indices", {1, 2}, {1, 3});
+  test.AddInput<bool>("updates", {1, 2}, {true, false});
+  test.AddOutput<bool>("y", {1, 5}, {false, true, false, false, false});
+  test.Run();
+}
+
+TEST(Scatter, BoolInputWithAxis) {
+  scatter_bool_with_axis_tests("Scatter", 9);
+  scatter_bool_with_axis_tests("ScatterElements", 11);
+}
+
 }  // namespace test
 }  // namespace onnxruntime