mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Improve the performance of Scatter (#991)
* Improve the performance of scatter
This commit is contained in:
parent
cb25fd4d8a
commit
0f6b3ea575
1 changed files with 45 additions and 19 deletions
|
|
@ -30,7 +30,7 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Scatter);
|
||||
|
||||
template <class Tin>
|
||||
template <class Tin, class Tdata>
|
||||
Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
|
||||
const int64_t axis, Tensor* data_output) {
|
||||
const TensorShape& input_data_shape = data_input->Shape();
|
||||
|
|
@ -45,12 +45,11 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
|
|||
}
|
||||
|
||||
const auto input_elements = input_data_shape.Size();
|
||||
const auto element_bytes = data_input->DataType()->Size();
|
||||
const auto total_input_bytes = data_input->Size();
|
||||
|
||||
const uint8_t* src_base = reinterpret_cast<const uint8_t*>(data_input->DataRaw());
|
||||
uint8_t* dst_base = reinterpret_cast<uint8_t*>(data_output->MutableDataRaw());
|
||||
const bool is_string_type = data_input->DataType() == DataTypeImpl::GetType<std::string>();
|
||||
const Tdata* src_base = static_cast<const Tdata*>(data_input->DataRaw());
|
||||
Tdata* dst_base = static_cast<Tdata*>(data_output->MutableDataRaw());
|
||||
bool is_string_type = data_input->DataType() == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
// We allow runtime to re-use input for output. If input/output Tensor* are the same
|
||||
// we do not copy
|
||||
|
|
@ -110,7 +109,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
|
|||
}
|
||||
}
|
||||
|
||||
const uint8_t* update_data = reinterpret_cast<const uint8_t*>(updates_input->DataRaw());
|
||||
const Tdata* update_data = static_cast<const Tdata*>(updates_input->DataRaw());
|
||||
// For every update we compute the destination offset and copy it there
|
||||
for (int64_t index = 0; index < num_indices;) {
|
||||
const Tin axis_idx = indices_data[index];
|
||||
|
|
@ -127,16 +126,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
|
|||
}
|
||||
}
|
||||
|
||||
const size_t dst_offset_bytes = dst_offset * element_bytes;
|
||||
assert(dst_offset_bytes < total_input_bytes);
|
||||
if (is_string_type) {
|
||||
reinterpret_cast<std::string*>(dst_base)[dst_offset] =
|
||||
reinterpret_cast<const std::string*>(update_data)[index];
|
||||
} else {
|
||||
// Copy an element
|
||||
auto src_offset_bytes = index * element_bytes;
|
||||
memcpy(dst_base + dst_offset_bytes, update_data + src_offset_bytes, element_bytes);
|
||||
}
|
||||
dst_base[dst_offset] = update_data[index];
|
||||
|
||||
if (++index == num_indices) {
|
||||
break;
|
||||
|
|
@ -158,6 +148,38 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define DispatchOnIndexTypeAndTensorType(index_type, tensor_type, retval, function, ...) \
|
||||
if (tensor_type == DataTypeImpl::GetType<float>()) \
|
||||
retval = function<index_type, float>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<double>()) \
|
||||
retval = function<index_type, double>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int8_t>()) \
|
||||
retval = function<index_type, int8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int16_t>()) \
|
||||
retval = function<index_type, int16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int32_t>()) \
|
||||
retval = function<index_type, int32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int64_t>()) \
|
||||
retval = function<index_type, int64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint8_t>()) \
|
||||
retval = function<index_type, uint8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint16_t>()) \
|
||||
retval = function<index_type, uint16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint32_t>()) \
|
||||
retval = function<index_type, uint32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint64_t>()) \
|
||||
retval = function<index_type, uint64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<bool>()) \
|
||||
retval = function<index_type, bool>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
|
||||
retval = function<index_type, MLFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
|
||||
retval = function<index_type, BFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<std::string>()) \
|
||||
retval = function<index_type, std::string>(__VA_ARGS__); \
|
||||
else \
|
||||
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type)
|
||||
|
||||
Status Scatter::Compute(OpKernelContext* context) const {
|
||||
const auto* data_input = context->Input<Tensor>(0);
|
||||
const auto& input_data_shape = data_input->Shape();
|
||||
|
|
@ -203,12 +225,16 @@ Status Scatter::Compute(OpKernelContext* context) const {
|
|||
auto* data_output = context->Output(0, input_data_shape);
|
||||
|
||||
MLDataType Tind_type = indices_input->DataType();
|
||||
MLDataType Tdata_type = data_input->DataType();
|
||||
Status status;
|
||||
if (Tind_type == DataTypeImpl::GetType<int32_t>()) {
|
||||
return CopyScatterData<int32_t>(data_input, indices_input, updates_input, axis, data_output);
|
||||
DispatchOnIndexTypeAndTensorType(int32_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output);
|
||||
} else if (Tind_type == DataTypeImpl::GetType<int64_t>()) {
|
||||
return CopyScatterData<int64_t>(data_input, indices_input, updates_input, axis, data_output);
|
||||
DispatchOnIndexTypeAndTensorType(int64_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
|
||||
}
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue