Refactor - ScatterElements (#3559)

Refactor ScatterElements using MLTypeCallDispatcherRet to refactor
This commit is contained in:
pengwa 2020-04-21 08:58:42 +08:00 committed by GitHub
parent 2579a72a88
commit e233e6ba45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 55 deletions

View file

@ -34,44 +34,60 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);
#define TYPED_FUNCTION_CALL(T) \
if (utils::IsPrimitiveDataType<T>(T_type)) { \
T* output_data = output_tensor->template MutableData<T>(); \
const T* input_data = data_tensor->template Data<T>(); \
const T* update_data = updates_tensor->template Data<T>(); \
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) { \
const int32_t* indices_data = indices_tensor->template Data<int32_t>(); \
return ScatterElementsImpl( \
rank, \
reinterpret_cast<const ToCudaType<T>::MappedType*>(input_data), \
input_data_size, \
buffer_input_dims, \
buffer_input_strides, \
indices_data, \
indices_size, \
buffer_indices_dims, \
fdm_indices_strides, \
reinterpret_cast<const ToCudaType<T>::MappedType*>(update_data), \
axis, \
reinterpret_cast<ToCudaType<T>::MappedType*>(output_data)); \
} \
if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) { \
const int64_t* indices_data = indices_tensor->template Data<int64_t>(); \
return ScatterElementsImpl( \
rank, \
reinterpret_cast<const ToCudaType<T>::MappedType*>(input_data), \
input_data_size, \
buffer_input_dims, \
buffer_input_strides, \
indices_data, \
indices_size, \
buffer_indices_dims, \
fdm_indices_strides, \
reinterpret_cast<const ToCudaType<T>::MappedType*>(update_data), \
axis, \
reinterpret_cast<ToCudaType<T>::MappedType*>(output_data)); \
} \
template <typename T>
struct ScatterElements::ComputeImpl {
Status operator()(const Tensor* data_tensor,
const Tensor* updates_tensor,
const Tensor* indices_tensor,
Tensor* output_tensor,
const int rank,
const int64_t input_data_size,
TArray<int64_t>& buffer_input_dims,
TArray<int64_t>& buffer_input_strides,
const int64_t indices_size,
TArray<int64_t>& buffer_indices_dims,
TArray<fast_divmod>& fdm_indices_strides,
const int axis) const {
T* output_data = output_tensor->template MutableData<T>();
const T* input_data = data_tensor->template Data<T>();
const T* update_data = updates_tensor->template Data<T>();
typedef typename ToCudaType<T>::MappedType CudaT;
MLDataType Tin_type = indices_tensor->DataType();
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) {
const int32_t* indices_data = indices_tensor->template Data<int32_t>();
return ScatterElementsImpl(
rank,
reinterpret_cast<const CudaT*>(input_data),
input_data_size,
buffer_input_dims,
buffer_input_strides,
indices_data,
indices_size,
buffer_indices_dims,
fdm_indices_strides,
reinterpret_cast<const CudaT*>(update_data),
axis,
reinterpret_cast<CudaT*>(output_data));
} else if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) {
const int64_t* indices_data = indices_tensor->template Data<int64_t>();
return ScatterElementsImpl(
rank,
reinterpret_cast<const CudaT*>(input_data),
input_data_size,
buffer_input_dims,
buffer_input_strides,
indices_data,
indices_size,
buffer_indices_dims,
fdm_indices_strides,
reinterpret_cast<const CudaT*>(update_data),
axis,
reinterpret_cast<CudaT*>(output_data));
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements.");
}
};
Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
const auto* data_tensor = context->Input<Tensor>(0);
@ -131,23 +147,12 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
fdm_indices_strides[i] = fast_divmod(static_cast<int>(indices_strides[i]));
}
MLDataType Tin_type = indices_tensor->DataType();
MLDataType T_type = data_tensor->DataType();
TYPED_FUNCTION_CALL(float)
TYPED_FUNCTION_CALL(MLFloat16)
TYPED_FUNCTION_CALL(int16_t)
TYPED_FUNCTION_CALL(int8_t)
TYPED_FUNCTION_CALL(int32_t)
TYPED_FUNCTION_CALL(int64_t)
TYPED_FUNCTION_CALL(uint8_t)
TYPED_FUNCTION_CALL(uint16_t)
TYPED_FUNCTION_CALL(uint32_t)
TYPED_FUNCTION_CALL(uint64_t)
TYPED_FUNCTION_CALL(double)
TYPED_FUNCTION_CALL(bool)
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T is not supported yet in ScatterElements.");
utils::MLTypeCallDispatcherRet<Status, ComputeImpl, float, MLFloat16, int16_t, int8_t, int32_t,
int64_t, uint8_t, uint16_t, uint32_t, uint64_t, double, bool>
t_disp(data_tensor->GetElementType());
return t_disp.Invoke(data_tensor, updates_tensor, indices_tensor, output_tensor, rank,
input_data_size, buffer_input_dims, buffer_input_strides, indices_size,
buffer_indices_dims, fdm_indices_strides, axis);
}
} // namespace cuda

View file

@ -19,9 +19,11 @@ class ScatterElements final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
private:
template <typename T>
struct ComputeImpl;
int64_t axis_;
};
} // namespace cuda
} // namespace onnxruntime