mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Refactor - ScatterElements (#3559)
Refactor ScatterElements to use MLTypeCallDispatcherRet for data-type dispatch instead of the TYPED_FUNCTION_CALL macro
This commit is contained in:
parent
2579a72a88
commit
e233e6ba45
2 changed files with 62 additions and 55 deletions
|
|
@ -34,44 +34,60 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ScatterElements);
|
||||
|
||||
#define TYPED_FUNCTION_CALL(T) \
|
||||
if (utils::IsPrimitiveDataType<T>(T_type)) { \
|
||||
T* output_data = output_tensor->template MutableData<T>(); \
|
||||
const T* input_data = data_tensor->template Data<T>(); \
|
||||
const T* update_data = updates_tensor->template Data<T>(); \
|
||||
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) { \
|
||||
const int32_t* indices_data = indices_tensor->template Data<int32_t>(); \
|
||||
return ScatterElementsImpl( \
|
||||
rank, \
|
||||
reinterpret_cast<const ToCudaType<T>::MappedType*>(input_data), \
|
||||
input_data_size, \
|
||||
buffer_input_dims, \
|
||||
buffer_input_strides, \
|
||||
indices_data, \
|
||||
indices_size, \
|
||||
buffer_indices_dims, \
|
||||
fdm_indices_strides, \
|
||||
reinterpret_cast<const ToCudaType<T>::MappedType*>(update_data), \
|
||||
axis, \
|
||||
reinterpret_cast<ToCudaType<T>::MappedType*>(output_data)); \
|
||||
} \
|
||||
if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) { \
|
||||
const int64_t* indices_data = indices_tensor->template Data<int64_t>(); \
|
||||
return ScatterElementsImpl( \
|
||||
rank, \
|
||||
reinterpret_cast<const ToCudaType<T>::MappedType*>(input_data), \
|
||||
input_data_size, \
|
||||
buffer_input_dims, \
|
||||
buffer_input_strides, \
|
||||
indices_data, \
|
||||
indices_size, \
|
||||
buffer_indices_dims, \
|
||||
fdm_indices_strides, \
|
||||
reinterpret_cast<const ToCudaType<T>::MappedType*>(update_data), \
|
||||
axis, \
|
||||
reinterpret_cast<ToCudaType<T>::MappedType*>(output_data)); \
|
||||
} \
|
||||
// Functor that runs ScatterElements for one data element type T.
// ComputeInternal instantiates it through utils::MLTypeCallDispatcherRet,
// keyed on data_tensor->GetElementType(), and forwards all arguments unchanged.
//
// Returns the Status produced by ScatterElementsImpl, or a NOT_IMPLEMENTED
// Status when the indices tensor is neither int32_t nor int64_t.
template <typename T>
|
||||
struct ScatterElements::ComputeImpl {
|
||||
Status operator()(const Tensor* data_tensor,
|
||||
const Tensor* updates_tensor,
|
||||
const Tensor* indices_tensor,
|
||||
Tensor* output_tensor,
|
||||
const int rank,
|
||||
const int64_t input_data_size,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const int64_t indices_size,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& fdm_indices_strides,
|
||||
const int axis) const {
|
||||
// Typed views of the output, data and updates buffers; all three use T.
T* output_data = output_tensor->template MutableData<T>();
|
||||
const T* input_data = data_tensor->template Data<T>();
|
||||
const T* update_data = updates_tensor->template Data<T>();
|
||||
// CudaT is the type ToCudaType maps T to; the data pointers are
// reinterpret_cast to it before being handed to ScatterElementsImpl.
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
// The index element type (Tin) is resolved at run time, independently of T.
MLDataType Tin_type = indices_tensor->DataType();
|
||||
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) {
|
||||
const int32_t* indices_data = indices_tensor->template Data<int32_t>();
|
||||
return ScatterElementsImpl(
|
||||
rank,
|
||||
reinterpret_cast<const CudaT*>(input_data),
|
||||
input_data_size,
|
||||
buffer_input_dims,
|
||||
buffer_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
buffer_indices_dims,
|
||||
fdm_indices_strides,
|
||||
reinterpret_cast<const CudaT*>(update_data),
|
||||
axis,
|
||||
reinterpret_cast<CudaT*>(output_data));
|
||||
} else if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) {
|
||||
// Same call as the int32_t branch; only the indices pointer type differs.
const int64_t* indices_data = indices_tensor->template Data<int64_t>();
|
||||
return ScatterElementsImpl(
|
||||
rank,
|
||||
reinterpret_cast<const CudaT*>(input_data),
|
||||
input_data_size,
|
||||
buffer_input_dims,
|
||||
buffer_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
buffer_indices_dims,
|
||||
fdm_indices_strides,
|
||||
reinterpret_cast<const CudaT*>(update_data),
|
||||
axis,
|
||||
reinterpret_cast<CudaT*>(output_data));
|
||||
}
|
||||
|
||||
// Reached only when the indices tensor is neither int32_t nor int64_t.
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements.");
|
||||
}
|
||||
};
|
||||
|
||||
Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
||||
const auto* data_tensor = context->Input<Tensor>(0);
|
||||
|
|
@ -131,23 +147,12 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
|||
fdm_indices_strides[i] = fast_divmod(static_cast<int>(indices_strides[i]));
|
||||
}
|
||||
|
||||
MLDataType Tin_type = indices_tensor->DataType();
|
||||
MLDataType T_type = data_tensor->DataType();
|
||||
|
||||
TYPED_FUNCTION_CALL(float)
|
||||
TYPED_FUNCTION_CALL(MLFloat16)
|
||||
TYPED_FUNCTION_CALL(int16_t)
|
||||
TYPED_FUNCTION_CALL(int8_t)
|
||||
TYPED_FUNCTION_CALL(int32_t)
|
||||
TYPED_FUNCTION_CALL(int64_t)
|
||||
TYPED_FUNCTION_CALL(uint8_t)
|
||||
TYPED_FUNCTION_CALL(uint16_t)
|
||||
TYPED_FUNCTION_CALL(uint32_t)
|
||||
TYPED_FUNCTION_CALL(uint64_t)
|
||||
TYPED_FUNCTION_CALL(double)
|
||||
TYPED_FUNCTION_CALL(bool)
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T is not supported yet in ScatterElements.");
|
||||
utils::MLTypeCallDispatcherRet<Status, ComputeImpl, float, MLFloat16, int16_t, int8_t, int32_t,
|
||||
int64_t, uint8_t, uint16_t, uint32_t, uint64_t, double, bool>
|
||||
t_disp(data_tensor->GetElementType());
|
||||
return t_disp.Invoke(data_tensor, updates_tensor, indices_tensor, output_tensor, rank,
|
||||
input_data_size, buffer_input_dims, buffer_input_strides, indices_size,
|
||||
buffer_indices_dims, fdm_indices_strides, axis);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -19,9 +19,11 @@ class ScatterElements final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
struct ComputeImpl;
|
||||
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue