diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index 8407c7f42a..3c7ccd0652 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -34,44 +34,60 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType()}), ScatterElements); -#define TYPED_FUNCTION_CALL(T) \ - if (utils::IsPrimitiveDataType(T_type)) { \ - T* output_data = output_tensor->template MutableData(); \ - const T* input_data = data_tensor->template Data(); \ - const T* update_data = updates_tensor->template Data(); \ - if (utils::IsPrimitiveDataType(Tin_type)) { \ - const int32_t* indices_data = indices_tensor->template Data(); \ - return ScatterElementsImpl( \ - rank, \ - reinterpret_cast::MappedType*>(input_data), \ - input_data_size, \ - buffer_input_dims, \ - buffer_input_strides, \ - indices_data, \ - indices_size, \ - buffer_indices_dims, \ - fdm_indices_strides, \ - reinterpret_cast::MappedType*>(update_data), \ - axis, \ - reinterpret_cast::MappedType*>(output_data)); \ - } \ - if (utils::IsPrimitiveDataType(Tin_type)) { \ - const int64_t* indices_data = indices_tensor->template Data(); \ - return ScatterElementsImpl( \ - rank, \ - reinterpret_cast::MappedType*>(input_data), \ - input_data_size, \ - buffer_input_dims, \ - buffer_input_strides, \ - indices_data, \ - indices_size, \ - buffer_indices_dims, \ - fdm_indices_strides, \ - reinterpret_cast::MappedType*>(update_data), \ - axis, \ - reinterpret_cast::MappedType*>(output_data)); \ - } \ +template +struct ScatterElements::ComputeImpl { + Status operator()(const Tensor* data_tensor, + const Tensor* updates_tensor, + const Tensor* indices_tensor, + Tensor* output_tensor, + const int rank, + const int64_t input_data_size, + TArray& buffer_input_dims, + TArray& buffer_input_strides, + const int64_t indices_size, + TArray& buffer_indices_dims, + TArray& fdm_indices_strides, + const int axis) const { + T* output_data = output_tensor->template MutableData(); + const T* input_data = data_tensor->template Data(); + const T* update_data = updates_tensor->template Data(); + typedef typename ToCudaType::MappedType CudaT; + MLDataType Tin_type = indices_tensor->DataType(); + if (utils::IsPrimitiveDataType(Tin_type)) { + const int32_t* indices_data = indices_tensor->template Data(); + return ScatterElementsImpl( + rank, + reinterpret_cast(input_data), + input_data_size, + buffer_input_dims, + buffer_input_strides, + indices_data, + indices_size, + buffer_indices_dims, + fdm_indices_strides, + reinterpret_cast(update_data), + axis, + reinterpret_cast(output_data)); + } else if (utils::IsPrimitiveDataType(Tin_type)) { + const int64_t* indices_data = indices_tensor->template Data(); + return ScatterElementsImpl( + rank, + reinterpret_cast(input_data), + input_data_size, + buffer_input_dims, + buffer_input_strides, + indices_data, + indices_size, + buffer_indices_dims, + fdm_indices_strides, + reinterpret_cast(update_data), + axis, + reinterpret_cast(output_data)); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements."); } +}; Status ScatterElements::ComputeInternal(OpKernelContext* context) const { const auto* data_tensor = context->Input(0); @@ -131,23 +147,12 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { fdm_indices_strides[i] = fast_divmod(static_cast(indices_strides[i])); } - MLDataType Tin_type = indices_tensor->DataType(); - MLDataType T_type = data_tensor->DataType(); - - TYPED_FUNCTION_CALL(float) - TYPED_FUNCTION_CALL(MLFloat16) - TYPED_FUNCTION_CALL(int16_t) - TYPED_FUNCTION_CALL(int8_t) - TYPED_FUNCTION_CALL(int32_t) - TYPED_FUNCTION_CALL(int64_t) - TYPED_FUNCTION_CALL(uint8_t) - TYPED_FUNCTION_CALL(uint16_t) - TYPED_FUNCTION_CALL(uint32_t) - TYPED_FUNCTION_CALL(uint64_t) - TYPED_FUNCTION_CALL(double) - TYPED_FUNCTION_CALL(bool) - - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T is not supported yet in ScatterElements."); + utils::MLTypeCallDispatcherRet + t_disp(data_tensor->GetElementType()); + return t_disp.Invoke(data_tensor, updates_tensor, indices_tensor, output_tensor, rank, + input_data_size, buffer_input_dims, buffer_input_strides, indices_size, + buffer_indices_dims, fdm_indices_strides, axis); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index f70bf6b778..30e6435299 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -19,9 +19,11 @@ class ScatterElements final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; private: + template + struct ComputeImpl; + int64_t axis_; }; } // namespace cuda } // namespace onnxruntime -