diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 2bb18f947c..0462cc9ba0 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -19,7 +19,6 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Abs", *this); CreateSimpleOpBuilder("And", *this); CreateSimpleOpBuilder("Ceil", *this); - CreateSimpleOpBuilder("Cast", *this); CreateSimpleOpBuilder("Cos", *this); CreateSimpleOpBuilder("Div", *this); CreateSimpleOpBuilder("Equal", *this); @@ -54,6 +53,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Concat", *this); } + { + CreateCastOpBuilder("Cast", *this); + } + { CreateReduceOpBuilder("ReduceMax", *this); CreateReduceOpBuilder("ReduceMean", *this); diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index b178f87b30..2afffb2305 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -50,6 +50,8 @@ const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type); void CreateSimpleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc new file mode 100644 index 0000000000..13cecc36df --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace qnn { + +class CastOpBuilder : public BaseOpBuilder { + public: + CastOpBuilder() : BaseOpBuilder("CastOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CastOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + bool is_quantized_model, + std::vector& input_names, + bool do_op_validation = false) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool is_quantized_model, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + +}; + +Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + bool is_quantized_model, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + ORT_UNUSED_PARAMETER(is_quantized_model); // Ignore in all backends. Cast should use same QNN types across backends. + + const auto& inputs = node_unit.Inputs(); + ORT_ENFORCE(inputs.size() == 1, "QNN Cast node must have a single input."); + const auto& input = inputs[0]; + + const auto& input_name = input.node_arg.Name(); + + if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name; + input_names.push_back(input_name); + return Status::OK(); + } + + std::vector unpacked_tensor; + bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); + if (is_initializer_input) { + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(*input_tensor, unpacked_tensor)); + } + + Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, input_shape), + "Cannot get shape for QNN Cast node's input."); + + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED; + const auto* type_proto = input.node_arg.TypeAsProto(); + + ORT_RETURN_IF_ERROR(GetQnnDataType(false, // Do not try to get the quantized type. HTP cast supports normal types. + type_proto, + qnn_data_type)); + + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QNN_QUANTIZE_PARAMS_INIT, + std::move(input_shape), std::move(unpacked_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), + "Failed to add input tensor for QNN Cast node."); + input_names.push_back(input_name); + + return Status::OK(); +} + +Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool is_quantized_model, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(logger); + ORT_UNUSED_PARAMETER(is_quantized_model); // Ignore in all backends. Cast should use same QNN types across backends. + + const auto& outputs = node_unit.Outputs(); + ORT_ENFORCE(outputs.size() == 1, "QNN Cast node must have a single output."); + const auto& output = outputs[0]; + const auto& output_name = output.node_arg.Name(); + + const auto* type_proto = output.node_arg.TypeAsProto(); + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED; + ORT_RETURN_IF_ERROR(GetQnnDataType(false, // Do not try to get the quantized type. HTP cast supports normal types. + type_proto, + qnn_data_type)); + + std::vector output_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), + "Cannot get shape for QNN Cast node's output."); + const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); + + const Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensorwrapper(output_name, + tensor_type, + qnn_data_type, + QNN_QUANTIZE_PARAMS_INIT, + std::move(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), + "Failed to add output tensor for QNN Cast node."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), + qnn_def::package_name, + GetQnnOpType(node_unit.OpType()), + std::move(input_names), + {output_name}, + {}, + do_op_validation), + "Failed to create QNN Cast node."); + + return Status::OK(); +} + +void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 6e11048d09..c1a4c77d07 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -83,6 +83,7 @@ bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapp // Is NPU backend, is single node, case by case // Q/DQ nodes -- supported // Transpose nodes -- supported + // Cast nodes -- need to call CastOpBuilder::IsOpSupported if (is_npu_backend && NodeUnit::Type::SingleNode == node_unit.UnitType()) { if (IsQdqNode(node_unit)) { // Qnn has Quantize & Dequantize Op LOGS(logger, VERBOSE) << "Single Q/DQ node is supported for NPU backend. Node name: " << node_unit.Name(); @@ -95,9 +96,13 @@ bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapp return true; } - LOGS(logger, VERBOSE) << "Non-QDQ single node is not supported for NPU backend. Node name: " << node_unit.Name() - << " Op type: " << node_unit.OpType(); - return false; + // For Cast, need to call IsOpSupported (below) to validate input and output types. + // For other single non-qdq nodes, immediately return not supported. + if (node_unit.OpType() != "Cast") { + LOGS(logger, VERBOSE) << "Non-QDQ single node is not supported for NPU backend. Node name: " << node_unit.Name() + << " Op type: " << node_unit.OpType(); + return false; + } } // Non-NPU backend, quantized model not supported, but a QDQ node encountered diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc new file mode 100644 index 0000000000..58d4580cfb --- /dev/null +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +/** + * Creates a graph with a single Cast operator. + * + * \param shape The shape of the input and output. Input data is randomly generated with this shape. + * \param dst_type The destination type as an instance of the DataType enum in TensorProto. + * + * \return A function that builds the graph with the provided builder. + */ +template +static GetTestModelFn BuildCastTestCase(const std::vector& shape, + ONNX_NAMESPACE::TensorProto_DataType dst_type) { + return [shape, dst_type](ModelTestBuilder& builder) { + + // Random input data + auto input = builder.MakeInput(shape, static_cast(0), static_cast(20)); + + auto* output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {input}, {output}); + cast_node.AddAttribute("to", static_cast(dst_type)); + }; +} + +/** + * Runs a Cast model on the QNN CPU or HTP backend. Checks the graph node assignment, and that inference + * outputs for QNN and CPU match. + * + * \param shape The shape of the input and output. Input data is randomly generated with this shape. + * \param dst_type The destination type as an instance of the DataType enum in TensorProto. + * \param test_description Description of the test for error reporting. + * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). + * \param use_htp True to run on HTP backend. Otherwise, runs on CPU. + */ +template +static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::TensorProto_DataType dst_type, + ExpectedEPNodeAssignment expected_ep_assignment, const char* test_description, + bool use_htp) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; +#else + provider_options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so"; +#endif + + constexpr int expected_nodes_in_partition = 1; + RunQnnModelTest(BuildCastTestCase(shape, dst_type), + provider_options, + 13, // opset + expected_ep_assignment, + expected_nodes_in_partition, + test_description); +} + +// +// CPU tests: +// + +// Cast int32_t to float on CPU +TEST(QnnCPUBackendTests, TestCastInt32ToFloat) { + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "TestCastInt32ToFloat", false); +} + +// Cast uint8_t to float on CPU +TEST(QnnCPUBackendTests, TestCastUInt8ToFloat) { + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "TestCastUInt8ToFloat", false); +} + +// Cast float to int32_t on CPU +TEST(QnnCPUBackendTests, TestCastFloatToInt32) { + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, + "TestCastInt32ToFloat", false); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Cast int32_t to float on HTP +TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "TestCastInt32ToFloatHTP", true); +} + +// Cast uint8_t to float on HTP +TEST_F(QnnHTPBackendTests, TestCastUInt8ToFloatHTP) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "TestCastUInt8ToFloatHTP", true); +} + +// Cast float to int32_t on HTP +TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, + "TestCastFloatToInt32HTP", true); +} +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) \ No newline at end of file