From 0ccfe6c86a144248b9ec0d6a78efa42bfe37cc16 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 30 Mar 2021 11:02:24 -0700 Subject: [PATCH] Enable type reduction for Scatter/ScatterElements CPU kernels (#7171) Enable type reduction for Scatter/ScatterElements CPU kernels. Some refactoring to reduce binary size. Add MLTypeCallDispatcher methods. Minor cleanup for Pad CPU kernel. --- .../core/framework/data_types_internal.h | 54 ++++++- onnxruntime/core/providers/cpu/tensor/pad.cc | 14 +- .../core/providers/cpu/tensor/scatter.cc | 144 ++++++++++++------ .../operator_type_usage_processors.py | 4 +- 4 files changed, 163 insertions(+), 53 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index b81ed85bf7..7ad932b2b8 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -350,10 +350,12 @@ class MLTypeCallDispatcher { } /** - * Invokes Fn<..., T> with leading template arguments and the specified arguments. + * Invokes Fn<..., T> with leading template arguments and the specified + * arguments. * * @tparam Fn The function object template. - * @tparam LeadingTemplateArgTypeList A type list of the leading template arguments. + * @tparam LeadingTemplateArgTypeList A type list of the leading template + * arguments. * @tparam Args The argument types. */ template