Enable type reduction for Scatter/ScatterElements CPU kernels (#7171)

Enable type reduction for Scatter/ScatterElements CPU kernels. Some refactoring to reduce binary size.
Add MLTypeCallDispatcher methods.
Minor cleanup for Pad CPU kernel.
This commit is contained in:
Edward Chen 2021-03-30 11:02:24 -07:00 committed by GitHub
parent 07201bac7a
commit 0ccfe6c86a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 53 deletions

View file

@ -350,10 +350,12 @@ class MLTypeCallDispatcher {
}
/**
* Invokes Fn<..., T> with leading template arguments and the specified arguments.
* Invokes Fn<..., T> with leading template arguments and the specified
* arguments.
*
* @tparam Fn The function object template.
* @tparam LeadingTemplateArgTypeList A type list of the leading template arguments.
* @tparam LeadingTemplateArgTypeList A type list of the leading template
* arguments.
* @tparam Args The argument types.
*/
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
@ -403,11 +405,55 @@ class MLTypeCallDispatcher {
*/
template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
Ret, Fn, UnsupportedPolicy, TypeList<>>(
std::forward<Args>(args)...);
}
/**
* Invokes Fn<..., T> with leading template arguments and the specified
* arguments and returns the result.
*
* @tparam Ret The return type. Fn should return a type convertible to Ret.
* @tparam Fn The function object template.
* @tparam LeadingTemplateArgTypeList A type list of the leading template
* arguments.
* @tparam Args The argument types.
*/
template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
std::forward<Args>(args)...);
}
/**
* Invokes Fn<..., T> with leading template arguments and the specified
* arguments and returns the result.
*
* @tparam Ret The return type. Fn should return a type convertible to Ret.
* @tparam Fn The function object template.
* @tparam UnsupportedPolicy The policy used to handle unsupported types.
* See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
* for an example.
* @tparam LeadingTemplateArgTypeList A type list of the leading template
* arguments.
* @tparam Args The argument types.
*/
template <class Ret,
template <typename...> class Fn,
class UnsupportedPolicy,
typename LeadingTemplateArgTypeList,
typename... Args>
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
// call helper.Invoke() with Fn<T> for each T in Types
// given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
// call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
static_cast<void>(std::array<int, sizeof...(Types)>{
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});
helper.template Invoke<Types>(
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
std::forward<Args>(args)...)...});
// avoid "unused parameter" warning for the case where Types is empty
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});

View file

@ -517,15 +517,21 @@ Status Pad::Compute(OpKernelContext* ctx) const {
slices_to_use = &slices_;
}
Status pad_status{};
switch (element_size) {
case sizeof(uint32_t):
return PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
break;
case sizeof(uint64_t):
return PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
break;
case sizeof(uint8_t):
return PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
break;
}
return pad_status;
}
}; // namespace onnxruntime

View file

@ -2,15 +2,41 @@
// Licensed under the MIT License.
//https://github.com/onnx/onnx/blob/master/docs/Operators.md#Scatter
#include <type_traits>
#include "gsl/gsl"
#include "core/common/common.h"
#include "core/framework/element_type_lists.h"
#include "core/framework/op_kernel.h"
#include "core/providers/common.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h"
#endif
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0, element_type_lists::All);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0, element_type_lists::All);
} // namespace op_kernel_type_control
using ScatterDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0);
using EnabledScatterDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Scatter, Input, 0);
using ScatterElementsDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0);
using EnabledScatterElementsDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ScatterElements, Input, 0);
template <typename EnabledDataTypes>
class Scatter final : public OpKernel {
public:
explicit Scatter(const OpKernelInfo& info) : OpKernel(info) {
@ -30,9 +56,11 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
9, 10,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<ScatterDataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledScatterDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
Scatter<EnabledScatterDataTypes>);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ScatterElements,
@ -40,18 +68,22 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
12,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<ScatterElementsDataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
Scatter<EnabledScatterElementsDataTypes>);
ONNX_CPU_OPERATOR_KERNEL(
ScatterElements,
13,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<ScatterElementsDataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
Scatter<EnabledScatterElementsDataTypes>);
template <class T>
struct Func_Assignment {
@ -60,20 +92,20 @@ struct Func_Assignment {
}
};
template <class Tin, class Tdata, typename FuncT>
Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
const TensorShape& input_data_shape = data_input->Shape();
const Tin* indices_data_raw = indices_input->template Data<Tin>();
const auto num_indices = indices_input->Shape().Size();
template <class TIndex>
Status GetIndices(
const Tensor& data_input, const Tensor& indices_input, int64_t axis,
std::vector<int64_t>& indices_data) {
const auto& input_data_shape = data_input.Shape();
const auto* indices_data_raw = indices_input.template Data<TIndex>();
const auto num_indices = indices_input.Shape().Size();
const auto axis_dim_limit = input_data_shape[axis];
std::vector<Tin> indices_data;
indices_data.reserve(num_indices);
auto axis_dim_limit = input_data_shape[axis];
std::vector<int64_t> indices_data_result;
indices_data_result.reserve(num_indices);
for (int64_t i = 0; i < num_indices; ++i) {
Tin idx = indices_data_raw[i];
const int64_t idx = static_cast<int64_t>(indices_data_raw[i]);
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@ -82,20 +114,32 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
",", axis_dim_limit - 1, "]");
}
indices_data.push_back(idx < 0 ? idx + static_cast<Tin>(axis_dim_limit) : idx);
indices_data_result.push_back(idx < 0 ? idx + axis_dim_limit : idx);
}
indices_data = std::move(indices_data_result);
return Status::OK();
}
template <class Tdata, typename FuncT>
Status CopyScatterData(
const FuncT& func,
const Tensor* data_input, const std::vector<int64_t>& indices_data, const Tensor* updates_input, int64_t axis,
Tensor* data_output) {
const TensorShape& input_data_shape = data_input->Shape();
const auto input_elements = input_data_shape.Size();
const auto total_input_bytes = data_input->SizeInBytes();
const auto num_indices = gsl::narrow<int64_t>(indices_data.size());
const auto* src_base = static_cast<const Tdata*>(data_input->DataRaw());
auto* dst_base = static_cast<Tdata*>(data_output->MutableDataRaw());
const bool is_string_type = data_input->IsDataTypeString();
// We allow runtime to re-use input for output. If input/output Tensor* are the same
// we do not copy
if (src_base != dst_base) {
if (is_string_type) {
if (std::is_same<Tdata, std::string>::value) {
const auto* str_begin = data_input->template Data<std::string>();
const std::string* str_end = str_begin + input_elements;
auto* dst = data_output->template MutableData<std::string>();
@ -153,7 +197,7 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
const auto* update_data = static_cast<const Tdata*>(updates_input->DataRaw());
// For every update we compute the destination offset and copy it there
for (int64_t index = 0; index < num_indices;) {
const Tin axis_idx = indices_data[index];
const auto axis_idx = indices_data[index];
// Compute the offset
// See comments above for dim_block_size
@ -189,17 +233,17 @@ Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor
return Status::OK();
}
template <class T, class... Args>
inline Status CopyInt32Index(Args&&... args) {
return CopyScatterData<int32_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
}
template <typename TData>
struct CopyScatterDataDispatchTarget {
Status operator()(const Tensor* data_input, const std::vector<int64_t>& indices_data, const Tensor* updates_input, int64_t axis,
Tensor* data_output) const {
return CopyScatterData<TData>(
Func_Assignment<TData>(), data_input, indices_data, updates_input, axis, data_output);
}
};
template <class T, class... Args>
inline Status CopyInt64Index(Args&&... args) {
return CopyScatterData<int64_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
}
Status Scatter::Compute(OpKernelContext* context) const {
template <typename EnabledDataTypes>
Status Scatter<EnabledDataTypes>::Compute(OpKernelContext* context) const {
const auto* data_input = context->Input<Tensor>(0);
const auto& input_data_shape = data_input->Shape();
const auto axis = HandleNegativeAxis(axis_, input_data_shape.NumDimensions());
@ -241,17 +285,29 @@ Status Scatter::Compute(OpKernelContext* context) const {
}
}
auto* data_output = context->Output(0, input_data_shape);
Status status{};
const auto index_type = indices_input->GetElementType();
std::vector<int64_t> indices_data{};
MLDataType Tdata_type = data_input->DataType();
Status status;
if (indices_input->IsDataType<int32_t>()) {
DispatchOnTensorTypeWithReturn(Tdata_type, status, CopyInt32Index, data_input, indices_input, updates_input, axis, data_output);
} else if (indices_input->IsDataType<int64_t>()) {
DispatchOnTensorTypeWithReturn(Tdata_type, status, CopyInt64Index, data_input, indices_input, updates_input, axis, data_output);
if (index_type == utils::ToTensorProtoElementType<int32_t>()) {
status = GetIndices<int32_t>(*data_input, *indices_input, axis, indices_data);
} else if (index_type == utils::ToTensorProtoElementType<int64_t>()) {
status = GetIndices<int64_t>(*data_input, *indices_input, axis, indices_data);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices type is not supported.");
}
if (!status.IsOK()) {
return status;
}
auto* data_output = context->Output(0, input_data_shape);
const auto data_type = data_input->GetElementType();
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{data_type};
status = dispatcher.template InvokeRet<Status, CopyScatterDataDispatchTarget>(
data_input, indices_data, updates_input, axis, data_output);
return status;
}
@ -269,7 +325,9 @@ struct Func_Add {
template <class Tin, class Tdata>
Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
return CopyScatterData<Tin, Tdata>(Func_Add<Tdata>(), data_output, indices_input, updates_input, axis, data_output);
std::vector<int64_t> indices_data{};
ORT_RETURN_IF_ERROR(GetIndices<Tin>(*data_output, *indices_input, axis, indices_data));
return CopyScatterData<Tdata>(Func_Add<Tdata>(), data_output, indices_data, updates_input, axis, data_output);
}
#define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \

View file

@ -337,8 +337,8 @@ def _create_operator_type_usage_processors():
'Range', 'Reciprocal', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp',
'ReduceMax', 'ReduceMean', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'ReduceSumSquare',
'Relu', 'Resize', 'ReverseSequence', 'RoiAlign', 'Round',
'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin', 'Softmax', 'Split',
'SplitToSequence', 'Sqrt', 'Sum',
'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin',
'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum',
'Tanh', 'TopK', 'Transpose',
'Unique',
'Where']