Fix the issue for Gather int64 indices handling (#23274)

### Description
Fix the issue for Gather int64 indices handling. Make it still insert Cast node if it's non-quantized Gather node.
This commit is contained in:
Hector Li 2025-01-08 12:52:08 -08:00 committed by GitHub
parent 5b9c968eaa
commit 76d6345f0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 72 additions and 28 deletions

View file

@ -126,23 +126,10 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
std::vector<std::string>& input_names,
bool do_op_validation) {
const auto& input_name = indices_input.node_arg.Name();
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
input_names.push_back(input_name);
return Status::OK();
}
TensorInfo indices_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));
const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
ORT_RETURN_IF(is_npu_backend &&
(indices_info.qnn_data_type == QNN_DATATYPE_INT_64) &&
!(indices_info.is_initializer || is_graph_input),
"HTP backend doesn't support a Gather* op with a dynamic int64 input activation ",
"unless it is a graph input.");
std::vector<uint8_t> qnn_indices_bytes;
// Get raw bytes for static indices.
@ -165,27 +152,22 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name);
std::vector<uint32_t> cast_output_shape(indices_info.shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape), std::move(qnn_indices_bytes));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
} else {
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape), std::move(qnn_indices_bytes));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}
// Insert QNN Cast op to convert dynamic indices from int64 to int32.
std::string indices_input_name(input_name);
if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
assert(!indices_info.is_initializer);
indices_input_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(), std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{indices_input_name},
{},
do_op_validation),
"Failed to add node.");
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddInt64CastNode(input_name, indices_input_name,
std::move(cast_output_shape),
do_op_validation));
}
input_names.push_back(indices_input_name);

View file

@ -224,6 +224,24 @@ class QnnModelWrapper {
tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output);
}
Status AddInt64CastNode(const std::string& input_name, std::string& cast_output_name,
std::vector<uint32_t>&& cast_output_shape, bool do_op_validation) {
cast_output_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(cast_output_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(), std::move(cast_output_shape));
ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(CreateQnnNode(cast_output_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{cast_output_name},
{},
do_op_validation),
"Failed to add node.");
return Status::OK();
}
Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
std::vector<uint8_t>& unpacked_tensor) const;

View file

@ -119,6 +119,15 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) {
ExpectedEPNodeAssignment::All);
}
// negative indices
TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_NegativeIndices) {
RunQDQGatherOpTest<uint8_t, int32_t>(TestInputDef<float>({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}),
TestInputDef<int32_t>({2, 2}, true, {-1, 1, 1, 2}),
{utils::MakeAttribute("axis", static_cast<int64_t>(0))},
13,
ExpectedEPNodeAssignment::All);
}
// Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all
// nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP.
//
@ -148,6 +157,41 @@ TEST_F(QnnHTPBackendTests, DISABLED_GatherOp_IndicesStaticInt32_Axis1) {
ExpectedEPNodeAssignment::All);
}
// Runs a non-QDQ model on HTP and compares output to CPU EP.
template <typename InputType1 = float, typename InputType2 = float>
static void RunOpTest(const std::string& op_type,
const TestInputDef<InputType1>& input_def_1,
const TestInputDef<InputType2>& input_defs_2,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
int opset_version,
ExpectedEPNodeAssignment expected_ep_assignment,
const std::string& op_domain = kOnnxDomain,
float fp32_abs_err = 1e-3f) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
// Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs.
RunQnnModelTest(BuildOpTestCase<InputType1, InputType2>(op_type, {input_def_1}, {input_defs_2}, attrs, op_domain),
provider_options,
opset_version,
expected_ep_assignment,
fp32_abs_err);
}
// Non-QDQ model, Gather with static input and dynamic int64 indices
TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64) {
RunOpTest<float, int64_t>("Gather",
TestInputDef<float>({3, 2}, true, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}),
TestInputDef<int64_t>({2, 2}, false, {0, 1, 1, 2}),
{utils::MakeAttribute("axis", static_cast<int64_t>(0))},
13,
ExpectedEPNodeAssignment::All);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
} // namespace test
} // namespace onnxruntime