mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Enable type reduction for Scatter/ScatterElements CPU kernels (#7171)
Enable type reduction for Scatter/ScatterElements CPU kernels. Some refactoring to reduce binary size. Add MLTypeCallDispatcher methods. Minor cleanup for Pad CPU kernel.
This commit is contained in:
parent
07201bac7a
commit
0ccfe6c86a
4 changed files with 163 additions and 53 deletions
|
|
@ -350,10 +350,12 @@ class MLTypeCallDispatcher {
|
|||
}
|
||||
|
||||
/**
|
||||
* Invokes Fn<..., T> with leading template arguments and the specified arguments.
|
||||
* Invokes Fn<..., T> with leading template arguments and the specified
|
||||
* arguments.
|
||||
*
|
||||
* @tparam Fn The function object template.
|
||||
* @tparam LeadingTemplateArgTypeList A type list of the leading template arguments.
|
||||
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
||||
* arguments.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
|
||||
|
|
@ -403,11 +405,55 @@ class MLTypeCallDispatcher {
|
|||
*/
|
||||
template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
|
||||
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
|
||||
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
|
||||
Ret, Fn, UnsupportedPolicy, TypeList<>>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes Fn<..., T> with leading template arguments and the specified
|
||||
* arguments and returns the result.
|
||||
*
|
||||
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
||||
* @tparam Fn The function object template.
|
||||
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
||||
* arguments.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
|
||||
Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
|
||||
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
|
||||
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes Fn<..., T> with leading template arguments and the specified
|
||||
* arguments and returns the result.
|
||||
*
|
||||
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
||||
* @tparam Fn The function object template.
|
||||
* @tparam UnsupportedPolicy The policy used to handle unsupported types.
|
||||
* See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
|
||||
* for an example.
|
||||
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
||||
* arguments.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <class Ret,
|
||||
template <typename...> class Fn,
|
||||
class UnsupportedPolicy,
|
||||
typename LeadingTemplateArgTypeList,
|
||||
typename... Args>
|
||||
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args&&... args) const {
|
||||
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
|
||||
|
||||
// call helper.Invoke() with Fn<T> for each T in Types
|
||||
// given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
|
||||
// call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
|
||||
static_cast<void>(std::array<int, sizeof...(Types)>{
|
||||
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});
|
||||
helper.template Invoke<Types>(
|
||||
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
|
||||
std::forward<Args>(args)...)...});
|
||||
|
||||
// avoid "unused parameter" warning for the case where Types is empty
|
||||
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
|
||||
|
|
|
|||
|
|
@ -517,15 +517,21 @@ Status Pad::Compute(OpKernelContext* ctx) const {
|
|||
slices_to_use = &slices_;
|
||||
}
|
||||
|
||||
Status pad_status{};
|
||||
switch (element_size) {
|
||||
case sizeof(uint32_t):
|
||||
return PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
|
||||
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
return PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
|
||||
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
return PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
|
||||
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
|
||||
break;
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
|
||||
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
|
||||
break;
|
||||
}
|
||||
return pad_status;
|
||||
}
|
||||
}; // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,15 +2,41 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
//https://github.com/onnx/onnx/blob/master/docs/Operators.md#Scatter
|
||||
#include <type_traits>
|
||||
|
||||
#include "gsl/gsl"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/element_type_lists.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/providers/op_kernel_type_control_utils.h"
|
||||
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
|
||||
#include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0, element_type_lists::All);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0, element_type_lists::All);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
using ScatterDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0);
|
||||
using EnabledScatterDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0);
|
||||
|
||||
using ScatterElementsDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0);
|
||||
using EnabledScatterElementsDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0);
|
||||
|
||||
template <typename EnabledDataTypes>
|
||||
class Scatter final : public OpKernel {
|
||||
public:
|
||||
explicit Scatter(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -30,9 +56,11 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
9, 10,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Scatter);
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<ScatterDataTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterDataTypes>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
Scatter<EnabledScatterDataTypes>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ScatterElements,
|
||||
|
|
@ -40,18 +68,22 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
12,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Scatter);
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<ScatterElementsDataTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
Scatter<EnabledScatterElementsDataTypes>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ScatterElements,
|
||||
13,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Scatter);
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<ScatterElementsDataTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
Scatter<EnabledScatterElementsDataTypes>);
|
||||
|
||||
template <class T>
|
||||
struct Func_Assignment {
|
||||
|
|
@ -60,20 +92,20 @@ struct Func_Assignment {
|
|||
}
|
||||
};
|
||||
|
||||
template <class Tin, class Tdata, typename FuncT>
|
||||
Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
|
||||
const int64_t axis, Tensor* data_output) {
|
||||
const TensorShape& input_data_shape = data_input->Shape();
|
||||
const Tin* indices_data_raw = indices_input->template Data<Tin>();
|
||||
const auto num_indices = indices_input->Shape().Size();
|
||||
template <class TIndex>
|
||||
Status GetIndices(
|
||||
const Tensor& data_input, const Tensor& indices_input, int64_t axis,
|
||||
std::vector<int64_t>& indices_data) {
|
||||
const auto& input_data_shape = data_input.Shape();
|
||||
const auto* indices_data_raw = indices_input.template Data<TIndex>();
|
||||
const auto num_indices = indices_input.Shape().Size();
|
||||
const auto axis_dim_limit = input_data_shape[axis];
|
||||
|
||||
std::vector<Tin> indices_data;
|
||||
indices_data.reserve(num_indices);
|
||||
|
||||
auto axis_dim_limit = input_data_shape[axis];
|
||||
std::vector<int64_t> indices_data_result;
|
||||
indices_data_result.reserve(num_indices);
|
||||
|
||||
for (int64_t i = 0; i < num_indices; ++i) {
|
||||
Tin idx = indices_data_raw[i];
|
||||
const int64_t idx = static_cast<int64_t>(indices_data_raw[i]);
|
||||
|
||||
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
@ -82,20 +114,32 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
|
|||
",", axis_dim_limit - 1, "]");
|
||||
}
|
||||
|
||||
indices_data.push_back(idx < 0 ? idx + static_cast<Tin>(axis_dim_limit) : idx);
|
||||
indices_data_result.push_back(idx < 0 ? idx + axis_dim_limit : idx);
|
||||
}
|
||||
|
||||
indices_data = std::move(indices_data_result);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class Tdata, typename FuncT>
|
||||
Status CopyScatterData(
|
||||
const FuncT& func,
|
||||
const Tensor* data_input, const std::vector<int64_t>& indices_data, const Tensor* updates_input, int64_t axis,
|
||||
Tensor* data_output) {
|
||||
const TensorShape& input_data_shape = data_input->Shape();
|
||||
|
||||
const auto input_elements = input_data_shape.Size();
|
||||
const auto total_input_bytes = data_input->SizeInBytes();
|
||||
|
||||
const auto num_indices = gsl::narrow<int64_t>(indices_data.size());
|
||||
|
||||
const auto* src_base = static_cast<const Tdata*>(data_input->DataRaw());
|
||||
auto* dst_base = static_cast<Tdata*>(data_output->MutableDataRaw());
|
||||
const bool is_string_type = data_input->IsDataTypeString();
|
||||
|
||||
// We allow runtime to re-use input for output. If input/output Tensor* are the same
|
||||
// we do not copy
|
||||
if (src_base != dst_base) {
|
||||
if (is_string_type) {
|
||||
if (std::is_same<Tdata, std::string>::value) {
|
||||
const auto* str_begin = data_input->template Data<std::string>();
|
||||
const std::string* str_end = str_begin + input_elements;
|
||||
auto* dst = data_output->template MutableData<std::string>();
|
||||
|
|
@ -153,7 +197,7 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
|
|||
const auto* update_data = static_cast<const Tdata*>(updates_input->DataRaw());
|
||||
// For every update we compute the destination offset and copy it there
|
||||
for (int64_t index = 0; index < num_indices;) {
|
||||
const Tin axis_idx = indices_data[index];
|
||||
const auto axis_idx = indices_data[index];
|
||||
|
||||
// Compute the offset
|
||||
// See comments above for dim_block_size
|
||||
|
|
@ -189,17 +233,17 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class T, class... Args>
|
||||
inline Status CopyInt32Index(Args&&... args) {
|
||||
return CopyScatterData<int32_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename TData>
|
||||
struct CopyScatterDataDispatchTarget {
|
||||
Status operator()(const Tensor* data_input, const std::vector<int64_t>& indices_data, const Tensor* updates_input, int64_t axis,
|
||||
Tensor* data_output) const {
|
||||
return CopyScatterData<TData>(
|
||||
Func_Assignment<TData>(), data_input, indices_data, updates_input, axis, data_output);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, class... Args>
|
||||
inline Status CopyInt64Index(Args&&... args) {
|
||||
return CopyScatterData<int64_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
Status Scatter::Compute(OpKernelContext* context) const {
|
||||
template <typename EnabledDataTypes>
|
||||
Status Scatter<EnabledDataTypes>::Compute(OpKernelContext* context) const {
|
||||
const auto* data_input = context->Input<Tensor>(0);
|
||||
const auto& input_data_shape = data_input->Shape();
|
||||
const auto axis = HandleNegativeAxis(axis_, input_data_shape.NumDimensions());
|
||||
|
|
@ -241,17 +285,29 @@ Status Scatter::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
auto* data_output = context->Output(0, input_data_shape);
|
||||
Status status{};
|
||||
const auto index_type = indices_input->GetElementType();
|
||||
std::vector<int64_t> indices_data{};
|
||||
|
||||
MLDataType Tdata_type = data_input->DataType();
|
||||
Status status;
|
||||
if (indices_input->IsDataType<int32_t>()) {
|
||||
DispatchOnTensorTypeWithReturn(Tdata_type, status, CopyInt32Index, data_input, indices_input, updates_input, axis, data_output);
|
||||
} else if (indices_input->IsDataType<int64_t>()) {
|
||||
DispatchOnTensorTypeWithReturn(Tdata_type, status, CopyInt64Index, data_input, indices_input, updates_input, axis, data_output);
|
||||
if (index_type == utils::ToTensorProtoElementType<int32_t>()) {
|
||||
status = GetIndices<int32_t>(*data_input, *indices_input, axis, indices_data);
|
||||
} else if (index_type == utils::ToTensorProtoElementType<int64_t>()) {
|
||||
status = GetIndices<int64_t>(*data_input, *indices_input, axis, indices_data);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices type is not supported.");
|
||||
}
|
||||
|
||||
if (!status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto* data_output = context->Output(0, input_data_shape);
|
||||
const auto data_type = data_input->GetElementType();
|
||||
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{data_type};
|
||||
status = dispatcher.template InvokeRet<Status, CopyScatterDataDispatchTarget>(
|
||||
data_input, indices_data, updates_input, axis, data_output);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
@ -269,7 +325,9 @@ struct Func_Add {
|
|||
template <class Tin, class Tdata>
|
||||
Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input,
|
||||
const int64_t axis, Tensor* data_output) {
|
||||
return CopyScatterData<Tin, Tdata>(Func_Add<Tdata>(), data_output, indices_input, updates_input, axis, data_output);
|
||||
std::vector<int64_t> indices_data{};
|
||||
ORT_RETURN_IF_ERROR(GetIndices<Tin>(*data_output, *indices_input, axis, indices_data));
|
||||
return CopyScatterData<Tdata>(Func_Add<Tdata>(), data_output, indices_data, updates_input, axis, data_output);
|
||||
}
|
||||
|
||||
#define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \
|
||||
|
|
|
|||
|
|
@ -337,8 +337,8 @@ def _create_operator_type_usage_processors():
|
|||
'Range', 'Reciprocal', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp',
|
||||
'ReduceMax', 'ReduceMean', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'ReduceSumSquare',
|
||||
'Relu', 'Resize', 'ReverseSequence', 'RoiAlign', 'Round',
|
||||
'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin', 'Softmax', 'Split',
|
||||
'SplitToSequence', 'Sqrt', 'Sum',
|
||||
'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin',
|
||||
'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum',
|
||||
'Tanh', 'TopK', 'Transpose',
|
||||
'Unique',
|
||||
'Where']
|
||||
|
|
|
|||
Loading…
Reference in a new issue