diff --git a/include/onnxruntime/core/optimizer/graph_transformer.h b/include/onnxruntime/core/optimizer/graph_transformer.h index 331ef5b530..61ed038eea 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer.h +++ b/include/onnxruntime/core/optimizer/graph_transformer.h @@ -6,6 +6,7 @@ #include "core/common/common.h" #include "core/common/inlined_containers.h" +#include "core/framework/data_types.h" #include "core/graph/graph_viewer.h" #include "core/optimizer/graph_transformer_level.h" @@ -69,4 +70,27 @@ class GraphTransformer { const std::string name_; const InlinedHashSet compatible_provider_types_; }; + +/** + * @brief Immutable object to identify a kernel registration. + * + * This data structure is used by the graph transformers to check whether + * a kernel is registered with the execution provider (i.e. has an + * implementation). If not, the transformer can not generate a node with + * such kernel. + */ +struct OpKernelRegistryId { + const std::string op_type_; + const std::string domain_; + const int version_; + const InlinedHashMap type_constraints_; + + OpKernelRegistryId( + const std::basic_string_view& op, + const std::basic_string_view& domain, + const int version, + std::initializer_list> init_list) + : op_type_(op), domain_(domain), version_(version), type_constraints_(init_list) {} +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index fe7698d040..7196465e5d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -350,7 +350,11 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault); - transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); + auto nhwc_transformer = 
std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + if (nhwc_transformer->IsActive()) { + transformers.emplace_back(std::move(nhwc_transformer)); + } // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So @@ -421,9 +425,12 @@ InlinedVector> GenerateTransformersForMinimalB // currently the only level 3 optimizer is the NhwcTransformer which is fully supported at runtime if (!saving) { #ifndef DISABLE_CONTRIB_OPS - const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault); - transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + if (nhwc_transformer->IsActive()) { + transformers.emplace_back(std::move(nhwc_transformer)); + } #else ORT_UNUSED_PARAMETER(cpu_execution_provider); #endif diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index 29c8de161a..1eecc0d586 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -2,18 +2,165 @@ // Licensed under the MIT License. 
#include +#include "core/mlas/inc/mlas.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/nhwc_transformer.h" #include "core/optimizer/utils.h" -#include "core/optimizer/transpose_optimizer/optimizer_utils.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; using namespace onnx_layout_transformation; +using namespace nhwc_map_internal; namespace onnxruntime { +static inline const OpTransformInfo* +NhwcConvLookup( + const OpTransformMap& conv_table, + const api::GraphRef& graph, + api::NodeRef& node) { + const auto& optype = node.OpType(); + const auto& domain = node.Domain(); + const auto inputs = node.Inputs(); + if (inputs.empty()) { + // node with no input, can't be our transformation candidate. + return nullptr; + } + const auto info = graph.GetValueInfo(inputs[0]); + const api::DataType dtype = info->DType(); + OpIdInfo key{optype, domain, dtype}; + + const auto iter = conv_table.find(key); + if (iter == conv_table.end()) { + return nullptr; + } + return &(iter->second); +} + +NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept + : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) { + if (!cpu_kernel_registry) { + // This is a CPU op nodes optimizer, not useful if cpu EP is not available. + return; + } + + // + // Constructing a mapping table from operators to be transformed to their target. + // Make sure that the new nodes we are about to create during graph transformation, + // their kernels are available in the cpu EP. 
+ // + + { + // int8 qconv -> int8 nhwc qconv + OpKernelRegistryId qconv_int8{ + "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType()}}}}; + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_, + qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::INT8), + OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true}); + conv_table_.emplace( + OpIdInfo("QLinearConv", kMSDomain, api::DataType::INT8), + OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true}); + } + } + + { + // uint8 qconv -> uint8 nhwc qconv + OpKernelRegistryId qconv_uint8{ + "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType()}}}}; + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_, + qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::UINT8), + OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true}); + conv_table_.emplace( + OpIdInfo("QLinearConv", kMSDomain, api::DataType::UINT8), + OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true}); + } + } + + { + // fp16 conv -> fp16 nhwc conv + OpKernelRegistryId nhwc_conv_fp16{ + "NhwcFusedConv", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType()}}}}; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, nhwc_conv_fp16.op_type_, 
nhwc_conv_fp16.domain_, + nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("Conv", kOnnxDomain, api::DataType::FLOAT16), + OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false}); + conv_table_.emplace( + OpIdInfo("FusedConv", kMSDomain, api::DataType::FLOAT16), + OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false}); + } + } + + { + // fp16 MaxPool -> fp16 nhwc MaxPool + OpKernelRegistryId nhwc_maxpool_fp16{ + "MaxPool", kMSInternalNHWCDomain, 12, {{"T", {DataTypeImpl::GetTensorType()}}}}; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, + nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("MaxPool", kOnnxDomain, api::DataType::FLOAT16), + OpTransformInfo{nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, nhwc_maxpool_fp16.version_, false}); + } + } + + { + // fp16 AveragePool -> fp16 nhwc AveragePool + OpKernelRegistryId nhwc_avgpool_fp16{ + "AveragePool", kMSInternalNHWCDomain, 11, {{"T", {DataTypeImpl::GetTensorType()}}}}; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, + nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("AveragePool", kOnnxDomain, api::DataType::FLOAT16), + OpTransformInfo{nhwc_avgpool_fp16.op_type_, 
nhwc_avgpool_fp16.domain_, nhwc_avgpool_fp16.version_, false}); + } + } + + { + // fp16 GlobalAveragePool -> fp16 nhwc GlobalAveragePool + OpKernelRegistryId nhwc_gavgpool_fp16{ + "GlobalAveragePool", kMSInternalNHWCDomain, 1, {{"T", {DataTypeImpl::GetTensorType()}}}}; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, + nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info); + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + conv_table_.emplace( + OpIdInfo("GlobalAveragePool", kOnnxDomain, api::DataType::FLOAT16), + OpTransformInfo{nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, nhwc_gavgpool_fp16.version_, false}); + } + } +}; + Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { #if defined(ORT_MINIMAL_BUILD) // update the producer/consumer info as previous optimizations may have invalidated it. @@ -36,41 +183,44 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - // Only QLinearConv needs to be handled explicitly. The rest will be transformed if needed during transpose - // optimization. 
- if (node->OpType() == "QLinearConv") { - auto domain = node->Domain(); - - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } - - // Skip if already transformed - if (node->GetAttributeIntDefault("channels_last", 0) == 1) { - continue; - } - - // Skip if unknown rank - auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape(); - if (shape == nullptr) { - continue; - } - - // Convert to channels last - size_t rank = shape->dim_size(); - node->SetAttributeInt("channels_last", 1); - - std::vector input_perm = ChannelFirstToLastPerm(rank); - std::vector output_perm = ChannelLastToFirstPerm(rank); - WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); - - if (domain != kMSDomain) { - SwapNodeOpTypeDomainAndSinceVersion(*api_graph, *node, "QLinearConv", kMSDomain, 1); - } - - modified = true; + // Only the operators listed in conv_table_ need to be handled explicitly. The rest will be + // transformed if needed during transpose optimization. 
+ const auto* transform = NhwcConvLookup(conv_table_, *api_graph, *node); + if (nullptr == transform) { + continue; } + + // Skip if already transformed + if (transform->has_channels_last_attrib_ && + node->GetAttributeIntDefault("channels_last", 0) == 1) { + continue; + } + + // Skip if unknown rank + auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape(); + if (shape == nullptr) { + continue; + } + + // Convert to channels last + if (transform->has_channels_last_attrib_) { + node->SetAttributeInt("channels_last", 1); + } + size_t rank = shape->dim_size(); + std::vector input_perm = ChannelFirstToLastPerm(rank); + std::vector output_perm = ChannelLastToFirstPerm(rank); + WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); + + // Replace the operator if needed + if (node->Domain() != transform->domain_ || + node->OpType() != transform->optype_ || + node->SinceVersion() != transform->version_) { + SwapNodeOpTypeDomainAndSinceVersion( + *api_graph, *node, transform->optype_, + transform->domain_, transform->version_); + } + + modified = true; } if (modified) { diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index a435a8b946..a3ea30f325 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -5,7 +5,64 @@ #include "core/common/common.h" #include "core/framework/execution_provider.h" +#include "core/framework/kernel_registry.h" #include "core/optimizer/graph_transformer.h" +#include "core/optimizer/transpose_optimizer/optimizer_utils.h" + +// +// Data structures internal to nhwc transformer implementation. +// Maybe we should use Pimpl Idiom to hide all these into +// an implementation class. But it would add an extra pointer +// chasing during runtime. +// +namespace nhwc_map_internal { + +/** + * @brief For identifying layout-sensitive operators + * as candidates for transforming to NHWC ops. 
+ */ +struct OpIdInfo { + const std::string optype_; + const std::string domain_; + const onnx_layout_transformation::api::DataType data_type_; + + OpIdInfo( + const std::basic_string_view& op, + const std::basic_string_view& domain, + onnx_layout_transformation::api::DataType data_type) + : optype_(op), domain_(domain), data_type_(data_type) {} + + bool operator==(const OpIdInfo& other) const { + return optype_ == other.optype_ && domain_ == other.domain_ && data_type_ == other.data_type_; + } +}; + +/** + * @brief Hash function for \ref OpIdInfo + */ +class OpIdHash { + public: + size_t operator()(const OpIdInfo& op) const { + size_t h1 = std::hash{}(op.optype_); + size_t h2 = std::hash{}(op.domain_); + size_t h3 = size_t(op.data_type_); + return h2 ^ (h1 << 4) ^ (h3 << 16); + } +}; + +/** + * @brief Information needed for operator layout transformation + */ +struct OpTransformInfo { + const std::string optype_; + const std::string domain_; + const int version_; + const bool has_channels_last_attrib_; +}; + +using OpTransformMap = std::unordered_map; + +} // namespace nhwc_map_internal namespace onnxruntime { @@ -17,14 +74,29 @@ and inserts nodes to transpose tensors as needed. */ class NhwcTransformer : public GraphTransformer { private: - AllocatorPtr cpu_allocator_; - public: - explicit NhwcTransformer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)){}; + explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept; + + /** + * @brief Usually called right after constructor, it shows whether + * this transformer should be used under current hardware configuration. 
+ * + * @return whether this transformer would be useful under current hardware config + */ + bool IsActive() { + return !conv_table_.empty(); + } private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + AllocatorPtr cpu_allocator_; + + /** + * A mapping table to identify operators that need to be transformed, and map + * them to the new operators that accept NHWC layout + */ + nhwc_map_internal::OpTransformMap conv_table_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 0959e7e61c..7820c593fb 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -129,7 +129,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { weights_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8"); weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims()); - weights_proto_u8.set_raw_data(w_temp.data(), w_temp.size()); + weights_proto_u8.set_raw_data(w_temp.data(), static_cast(w_temp.size())); input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; @@ -140,7 +140,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8"); r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims()); - r_proto_u8.set_raw_data(r_temp.data(), r_temp.size()); + r_proto_u8.set_raw_data(r_temp.data(), static_cast(r_temp.size())); input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8); ONNX_NAMESPACE::TensorProto r_zp_proto_u8; diff --git 
a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index 65a037cfa0..3c911c7404 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -172,6 +172,10 @@ std::optional> ApiValueInfo::Shape() const { api::DataType ApiValueInfo::DType() const { const auto* type = node_arg_.TypeAsProto(); + if (!type) { + return api::DataType::UNDEFINED; + } + if (!utils::HasTensorType(*type)) { return api::DataType::UNDEFINED; } diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index e23e792f81..d4bf25ab64 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -560,7 +560,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { static_cast(output_count), static_cast(group_output_channels), static_cast(kernel_dim), - 1, &gemm_params, thread_pool); + 1, &gemm_params, nullptr); } } }; diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index 413f456b80..c254d340cd 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -6,7 +6,7 @@ #include "gtest/gtest.h" #include "graph_transform_test_builder.h" - +#include "core/mlas/inc/mlas.h" #include "core/graph/graph.h" namespace onnxruntime { @@ -516,6 +516,167 @@ TEST(NhwcTransformerTests, ConvMixTensorRanks) { TransformerLevel::Level3); } +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +std::vector randomfp16(const std::vector& shape, MLFloat16 min, MLFloat16 max) { + std::vector val(detail::SizeFromDims(shape)); + float start = min.ToFloat(); + float end = max.ToFloat(); + float step = (end - start) / 128; + float value = start; + for (size_t i = 0; i < val.size(); ++i) { + value += 
step; + if (value > end) { + value = start; + } + val[i] = MLFloat16(value); + } + return val; +} + +template <> +NodeArg* ModelTestBuilder::MakeInput(const std::vector& shape, MLFloat16 min, MLFloat16 max) { + return MakeInput(shape, randomfp16(shape, min, max)); +} + +template <> +NodeArg* ModelTestBuilder::MakeInitializer(const std::vector& shape, MLFloat16 min, MLFloat16 max) { + return MakeInitializer(shape, randomfp16(shape, min, max)); +} + +TEST(NhwcTransformerTests, ConvFp16) { + auto test_case = [&](const std::vector& input_shape, const std::vector& weights_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* output_arg = builder.MakeOutput(); + auto* weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f)); + + builder.AddConvNode(input_arg, weight_arg, output_arg); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 1); + EXPECT_EQ(op_to_count["Transpose"], 2); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); + }; + + // Test the basic case of a single 1D/2D/3D convolution. 
+ test_case({1, 12, 37}, {32, 12, 5}); + test_case({1, 23, 13, 13}, {30, 23, 3, 3}); + test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}); +} + +TEST(NhwcTransformerTests, ConvMaxPoolFp16) { + auto test_case = [&](const std::vector& input_shape, const std::vector& weights_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* conv_output_arg = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + auto* conv_weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f)); + + builder.AddConvNode(input_arg, conv_weight_arg, conv_output_arg); + Node& pool_node = builder.AddNode("MaxPool", {conv_output_arg}, {output_arg}); + std::vector pads((weights_shape.size() - 2) * 2, 1); + pool_node.AddAttribute("pads", pads); + std::vector kernel_shape(weights_shape.size() - 2, 3); + pool_node.AddAttribute("kernel_shape", kernel_shape); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 1); + EXPECT_EQ(op_to_count["com.ms.internal.nhwc.MaxPool"], 1); + EXPECT_EQ(op_to_count["Transpose"], 2); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); + }; + + // Test the basic case of a single 1D/2D/3D convolution. 
+ test_case({5, 12, 37}, {128, 12, 5}); + test_case({3, 14, 13, 13}, {64, 14, 3, 3}); + test_case({1, 15, 11, 13, 15}, {31, 15, 5, 3, 3}); +} + +TEST(NhwcTransformerTests, ConvGlobalAveragePoolFp16) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* conv1_output_arg = builder.MakeIntermediate(); + auto* conv2_output_arg = builder.MakeIntermediate(); + auto* gavgpool1_output_arg = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + auto* conv1_weight_arg = builder.MakeInitializer({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* conv2_weight_arg = builder.MakeInitializer({16, 30, 1, 1}, MLFloat16(-1.5f), MLFloat16(1.5f)); + + Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg); + conv1_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + + builder.AddNode("GlobalAveragePool", {conv1_output_arg}, {gavgpool1_output_arg}); + builder.AddConvNode(gavgpool1_output_arg, conv2_weight_arg, conv2_output_arg); + builder.AddNode("GlobalAveragePool", {conv2_output_arg}, {output_arg}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 2); + EXPECT_EQ(op_to_count["com.ms.internal.nhwc.GlobalAveragePool"], 2); + EXPECT_EQ(op_to_count["Transpose"], 2); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); +} + +TEST(NhwcTransformerTests, ConvAveragePoolFp16) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* conv1_output_arg = builder.MakeIntermediate(); + auto* conv2_output_arg = builder.MakeIntermediate(); + auto* avgpool1_output_arg = builder.MakeIntermediate(); + auto* output_arg = 
builder.MakeOutput(); + auto* conv1_weight_arg = builder.MakeInitializer({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f)); + auto* conv2_weight_arg = builder.MakeInitializer({16, 30, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f)); + + Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg); + conv1_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + Node& avgpool_node1 = builder.AddNode( + "AveragePool", {conv1_output_arg}, {avgpool1_output_arg}); + avgpool_node1.AddAttribute("kernel_shape", std::vector{3, 3}); + avgpool_node1.AddAttribute("pads", std::vector{1, 1, 1, 1}); + + builder.AddConvNode(avgpool1_output_arg, conv2_weight_arg, conv2_output_arg); + Node& avgpool_node2 = builder.AddNode( + "AveragePool", {conv2_output_arg}, {output_arg}); + avgpool_node2.AddAttribute("kernel_shape", std::vector{3, 3}); + avgpool_node2.AddAttribute("pads", std::vector{1, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 2); + EXPECT_EQ(op_to_count["com.ms.internal.nhwc.AveragePool"], 2); + EXPECT_EQ(op_to_count["Transpose"], 2); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); +} + +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED + #endif // DISABLE_CONTRIB_OPS } // namespace test