Improve the performance of Scatter (#991)

* Improve the performance of scatter
This commit is contained in:
Yufeng Li 2019-05-14 13:35:15 -07:00 committed by Changming Sun
parent cb25fd4d8a
commit 0f6b3ea575

View file

@ -30,7 +30,7 @@ ONNX_CPU_OPERATOR_KERNEL(
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);
template <class Tin>
template <class Tin, class Tdata>
Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
const TensorShape& input_data_shape = data_input->Shape();
@ -45,12 +45,11 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
}
const auto input_elements = input_data_shape.Size();
const auto element_bytes = data_input->DataType()->Size();
const auto total_input_bytes = data_input->Size();
const uint8_t* src_base = reinterpret_cast<const uint8_t*>(data_input->DataRaw());
uint8_t* dst_base = reinterpret_cast<uint8_t*>(data_output->MutableDataRaw());
const bool is_string_type = data_input->DataType() == DataTypeImpl::GetType<std::string>();
const Tdata* src_base = static_cast<const Tdata*>(data_input->DataRaw());
Tdata* dst_base = static_cast<Tdata*>(data_output->MutableDataRaw());
bool is_string_type = data_input->DataType() == DataTypeImpl::GetType<std::string>();
// We allow runtime to re-use input for output. If input/output Tensor* are the same
// we do not copy
@ -110,7 +109,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
}
}
const uint8_t* update_data = reinterpret_cast<const uint8_t*>(updates_input->DataRaw());
const Tdata* update_data = static_cast<const Tdata*>(updates_input->DataRaw());
// For every update we compute the destination offset and copy it there
for (int64_t index = 0; index < num_indices;) {
const Tin axis_idx = indices_data[index];
@ -127,16 +126,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
}
}
const size_t dst_offset_bytes = dst_offset * element_bytes;
assert(dst_offset_bytes < total_input_bytes);
if (is_string_type) {
reinterpret_cast<std::string*>(dst_base)[dst_offset] =
reinterpret_cast<const std::string*>(update_data)[index];
} else {
// Copy an element
auto src_offset_bytes = index * element_bytes;
memcpy(dst_base + dst_offset_bytes, update_data + src_offset_bytes, element_bytes);
}
dst_base[dst_offset] = update_data[index];
if (++index == num_indices) {
break;
@ -158,6 +148,38 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
return Status::OK();
}
// Runtime-to-compile-time dispatch on a tensor element type.
//
// Compares `tensor_type` against DataTypeImpl::GetType<T>() for each supported
// element type T (float, double, int8_t..int64_t, uint8_t..uint64_t, bool,
// MLFloat16, BFloat16, std::string) and, on the first match, executes
//     retval = function<index_type, T>(__VA_ARGS__);
// If no type matches, ORT_ENFORCE(false, ...) is invoked with an
// "Unknown tensor type" message.
//
// Parameters:
//   index_type  - first template argument forwarded to `function`.
//   tensor_type - runtime type handle to test; it is evaluated once per
//                 comparison, so pass a side-effect-free expression.
//   retval      - lvalue that receives the result of the selected call.
//   function    - name of a template callable as function<index_type, T>(...).
//   ...         - arguments forwarded unchanged to `function`.
//
// The expansion is wrapped in do { ... } while (0) so that the macro behaves as
// a single statement: it can be used as the body of an unbraced if/else without
// its internal `else` branches capturing (or being captured by) the caller's.
// Invoke it with a trailing semicolon.
#define DispatchOnIndexTypeAndTensorType(index_type, tensor_type, retval, function, ...) \
do { \
if ((tensor_type) == DataTypeImpl::GetType<float>()) \
retval = function<index_type, float>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<double>()) \
retval = function<index_type, double>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<int8_t>()) \
retval = function<index_type, int8_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<int16_t>()) \
retval = function<index_type, int16_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<int32_t>()) \
retval = function<index_type, int32_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<int64_t>()) \
retval = function<index_type, int64_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<uint8_t>()) \
retval = function<index_type, uint8_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<uint16_t>()) \
retval = function<index_type, uint16_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<uint32_t>()) \
retval = function<index_type, uint32_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<uint64_t>()) \
retval = function<index_type, uint64_t>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<bool>()) \
retval = function<index_type, bool>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<MLFloat16>()) \
retval = function<index_type, MLFloat16>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<BFloat16>()) \
retval = function<index_type, BFloat16>(__VA_ARGS__); \
else if ((tensor_type) == DataTypeImpl::GetType<std::string>()) \
retval = function<index_type, std::string>(__VA_ARGS__); \
else { \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
} \
} while (0)
Status Scatter::Compute(OpKernelContext* context) const {
const auto* data_input = context->Input<Tensor>(0);
const auto& input_data_shape = data_input->Shape();
@ -203,12 +225,16 @@ Status Scatter::Compute(OpKernelContext* context) const {
auto* data_output = context->Output(0, input_data_shape);
MLDataType Tind_type = indices_input->DataType();
MLDataType Tdata_type = data_input->DataType();
Status status;
if (Tind_type == DataTypeImpl::GetType<int32_t>()) {
return CopyScatterData<int32_t>(data_input, indices_input, updates_input, axis, data_output);
DispatchOnIndexTypeAndTensorType(int32_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output);
} else if (Tind_type == DataTypeImpl::GetType<int64_t>()) {
return CopyScatterData<int64_t>(data_input, indices_input, updates_input, axis, data_output);
DispatchOnIndexTypeAndTensorType(int64_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
}
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t");
return status;
}
} // namespace onnxruntime