mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix the issue for Gather int64 indices handling (#23274)
### Description Fix the handling of int64 indices for Gather: a Cast node is still inserted when the Gather node is non-quantized.
This commit is contained in:
parent
5b9c968eaa
commit
76d6345f0b
3 changed files with 72 additions and 28 deletions
|
|
@ -126,23 +126,10 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
|
|||
std::vector<std::string>& input_names,
|
||||
bool do_op_validation) {
|
||||
const auto& input_name = indices_input.node_arg.Name();
|
||||
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
|
||||
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
|
||||
input_names.push_back(input_name);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorInfo indices_info = {};
|
||||
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));
|
||||
|
||||
const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
|
||||
const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
|
||||
ORT_RETURN_IF(is_npu_backend &&
|
||||
(indices_info.qnn_data_type == QNN_DATATYPE_INT_64) &&
|
||||
!(indices_info.is_initializer || is_graph_input),
|
||||
"HTP backend doesn't support a Gather* op with a dynamic int64 input activation ",
|
||||
"unless it is a graph input.");
|
||||
|
||||
std::vector<uint8_t> qnn_indices_bytes;
|
||||
|
||||
// Get raw bytes for static indices.
|
||||
|
|
@ -165,27 +152,22 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
|
|||
|
||||
Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name);
|
||||
std::vector<uint32_t> cast_output_shape(indices_info.shape);
|
||||
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
|
||||
std::move(indices_info.shape), std::move(qnn_indices_bytes));
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
|
||||
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
|
||||
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
|
||||
} else {
|
||||
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
|
||||
std::move(indices_info.shape), std::move(qnn_indices_bytes));
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
|
||||
}
|
||||
|
||||
// Insert QNN Cast op to convert dynamic indices from int64 to int32.
|
||||
std::string indices_input_name(input_name);
|
||||
if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
|
||||
assert(!indices_info.is_initializer);
|
||||
|
||||
indices_input_name = input_name + "_ort_qnn_ep_cast";
|
||||
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
|
||||
QnnQuantParamsWrapper(), std::move(cast_output_shape));
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
|
||||
QNN_OP_PACKAGE_NAME_QTI_AISW,
|
||||
"Cast",
|
||||
{input_name},
|
||||
{indices_input_name},
|
||||
{},
|
||||
do_op_validation),
|
||||
"Failed to add node.");
|
||||
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddInt64CastNode(input_name, indices_input_name,
|
||||
std::move(cast_output_shape),
|
||||
do_op_validation));
|
||||
}
|
||||
|
||||
input_names.push_back(indices_input_name);
|
||||
|
|
|
|||
|
|
@ -224,6 +224,24 @@ class QnnModelWrapper {
|
|||
tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output);
|
||||
}
|
||||
|
||||
// Adds a QNN "Cast" node whose output tensor is declared as int32.
//
// input_name        - name of the tensor that feeds the Cast node.
// cast_output_name  - out-parameter; overwritten with input_name + "_ort_qnn_ep_cast". The same
//                     string is used as the name of the created Cast node.
// cast_output_shape - shape assigned to the Cast output tensor (moved from).
// do_op_validation  - forwarded unchanged to CreateQnnNode.
//
// Returns Status::OK() when both the output tensor and the node were added; otherwise returns
// early with "Failed to add tensor." or "Failed to add node.".
Status AddInt64CastNode(const std::string& input_name, std::string& cast_output_name,
                        std::vector<uint32_t>&& cast_output_shape, bool do_op_validation) {
  // Derive the Cast output name from the input name.
  cast_output_name = input_name;
  cast_output_name += "_ort_qnn_ep_cast";

  // Register the int32 output tensor (native tensor type, no quantization parameters).
  QnnTensorWrapper cast_output(cast_output_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
                               QnnQuantParamsWrapper(), std::move(cast_output_shape));
  ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");

  // Create the Cast node: single input, single output, no parameters.
  ORT_RETURN_IF_NOT(CreateQnnNode(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, "Cast", {input_name}, {cast_output_name}, {}, do_op_validation),
                    "Failed to add node.");

  return Status::OK();
}
|
||||
|
||||
Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
|
||||
std::vector<uint8_t>& unpacked_tensor) const;
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,15 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) {
|
|||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// negative indices
|
||||
TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_NegativeIndices) {
  // Data to gather from: a {3, 2} float tensor.
  const TestInputDef<float> data_def({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});

  // {2, 2} int32 indices; the first entry is negative (-1).
  // NOTE(review): the `true` flag presumably marks the indices as a static initializer - confirm
  // against TestInputDef.
  const TestInputDef<int32_t> indices_def({2, 2}, true, {-1, 1, 1, 2});

  // Gather along axis 0 with opset 13; every node is expected to be assigned to the QNN EP.
  RunQDQGatherOpTest<uint8_t, int32_t>(data_def,
                                       indices_def,
                                       {utils::MakeAttribute("axis", static_cast<int64_t>(0))},
                                       13,
                                       ExpectedEPNodeAssignment::All);
}
|
||||
|
||||
// Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all
|
||||
// nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP.
|
||||
//
|
||||
|
|
@ -148,6 +157,41 @@ TEST_F(QnnHTPBackendTests, DISABLED_GatherOp_IndicesStaticInt32_Axis1) {
|
|||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// Runs a non-QDQ model on HTP and compares output to CPU EP.
|
||||
template <typename InputType1 = float, typename InputType2 = float>
|
||||
static void RunOpTest(const std::string& op_type,
|
||||
const TestInputDef<InputType1>& input_def_1,
|
||||
const TestInputDef<InputType2>& input_defs_2,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
|
||||
int opset_version,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment,
|
||||
const std::string& op_domain = kOnnxDomain,
|
||||
float fp32_abs_err = 1e-3f) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
// Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs.
|
||||
RunQnnModelTest(BuildOpTestCase<InputType1, InputType2>(op_type, {input_def_1}, {input_defs_2}, attrs, op_domain),
|
||||
provider_options,
|
||||
opset_version,
|
||||
expected_ep_assignment,
|
||||
fp32_abs_err);
|
||||
}
|
||||
|
||||
// Non-QDQ model, Gather with static input and dynamic int64 indices
|
||||
TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64) {
  // NOTE(review): the test name says "Static", but the indices below are created with the `false`
  // flag while the data uses `true`; presumably that flag means "is initializer", which would
  // make the indices dynamic - confirm against TestInputDef.

  // Data to gather from: a {3, 2} float tensor.
  const TestInputDef<float> data_def({3, 2}, true, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});

  // {2, 2} int64 indices.
  const TestInputDef<int64_t> indices_def({2, 2}, false, {0, 1, 1, 2});

  // Plain (non-QDQ) Gather along axis 0 with opset 13; every node is expected to be assigned
  // to the QNN EP.
  RunOpTest<float, int64_t>("Gather",
                            data_def,
                            indices_def,
                            {utils::MakeAttribute("axis", static_cast<int64_t>(0))},
                            13,
                            ExpectedEPNodeAssignment::All);
}
|
||||
|
||||
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue