NHWC graph optimizer (#15724)

### Description

Augment nhwc graph optimizer to accommodate fp16 operators.


### Motivation and Context

A new fp16 conv operator has been added, and it prefers the NHWC data
layout. We need to augment the existing graph optimizers so that they
make better use of the new operator.
This commit is contained in:
Chen Fu 2023-05-01 08:44:07 -07:00 committed by GitHub
parent d35850c142
commit 0e9472d391
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 464 additions and 46 deletions

View file

@ -6,6 +6,7 @@
#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/data_types.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/graph_transformer_level.h"
@ -69,4 +70,27 @@ class GraphTransformer {
const std::string name_;
const InlinedHashSet<std::string_view> compatible_provider_types_;
};
/**
 * @brief Immutable description of a single kernel registration.
 *
 * Graph transformers use this to ask an execution provider's kernel
 * registry whether a kernel exists (i.e. has an implementation) before
 * they generate a node that would need it. If the kernel is missing,
 * the transformer must not create such a node.
 *
 * All members are const, so an instance can be constructed but not
 * reassigned afterwards.
 */
struct OpKernelRegistryId {
  /// Operator type of the kernel, e.g. "QLinearConv".
  const std::string op_type_;
  /// Operator domain the kernel is registered under.
  const std::string domain_;
  /// Operator set version of the kernel.
  const int version_;
  /// Type constraint name (e.g. "T") mapped to the required data type.
  const InlinedHashMap<std::string, MLDataType> type_constraints_;

  /**
   * @param op         operator type, copied into op_type_
   * @param domain     operator domain, copied into domain_
   * @param version    operator set version
   * @param init_list  (constraint name, data type) pairs used to fill type_constraints_
   */
  OpKernelRegistryId(
      const std::string_view& op,
      const std::string_view& domain,
      const int version,
      std::initializer_list<std::pair<std::string, MLDataType>> init_list)
      : op_type_(op),
        domain_(domain),
        version_(version),
        type_constraints_(init_list) {}
};
} // namespace onnxruntime

View file

@ -350,7 +350,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<NchwcTransformer>());
}
auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault);
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}
// NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things
// of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for
// x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So
@ -421,9 +425,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
// currently the only level 3 optimizer is the NhwcTransformer which is fully supported at runtime
if (!saving) {
#ifndef DISABLE_CONTRIB_OPS
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault);
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}
#else
ORT_UNUSED_PARAMETER(cpu_execution_provider);
#endif

View file

@ -2,18 +2,165 @@
// Licensed under the MIT License.
#include <deque>
#include "core/mlas/inc/mlas.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
using namespace onnx_layout_transformation;
using namespace nhwc_map_internal;
namespace onnxruntime {
// Looks up the NHWC replacement registered for `node`.
//
// The lookup key is built from the node's operator type, its domain and the
// data type of its first input. Returns a pointer to the matching entry of
// `conv_table`, or nullptr when the node has no inputs or no entry matches.
// The returned pointer refers to storage owned by `conv_table`.
static inline const OpTransformInfo*
NhwcConvLookup(
    const OpTransformMap& conv_table,
    const api::GraphRef& graph,
    api::NodeRef& node) {
  const auto& op_name = node.OpType();
  const auto& op_domain = node.Domain();
  const auto node_inputs = node.Inputs();
  if (node_inputs.empty()) {
    // A node without inputs can never be a layout transformation candidate.
    return nullptr;
  }

  // The element type of the first input selects between e.g. the int8,
  // uint8 and fp16 flavours of the same operator.
  const auto first_input_info = graph.GetValueInfo(node_inputs[0]);
  const api::DataType first_input_type = first_input_info->DType();

  const OpIdInfo lookup_key{op_name, op_domain, first_input_type};
  const auto match = conv_table.find(lookup_key);
  return match != conv_table.end() ? &(match->second) : nullptr;
}
/**
 * @brief Construct the transformer and fill its operator mapping table.
 *
 * For each NHWC kernel this transformer knows how to target, the CPU kernel
 * registry is probed. A mapping from the original operator to the NHWC
 * operator is added to conv_table_ only when the probe succeeds, so nodes are
 * only rewritten into operators that have an implementation in the CPU EP.
 *
 * When no registry is supplied the table stays empty.
 *
 * NOTE(review): this constructor is declared noexcept but builds strings and
 * inserts into a hash map; an exception from any of those would terminate the
 * process. Confirm that this is the intended trade-off.
 *
 * @param cpu_allocator        allocator stored in cpu_allocator_
 * @param cpu_kernel_registry  kernel registry of the CPU execution provider; may be null
 */
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
    : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
  if (!cpu_kernel_registry) {
    // This is a CPU op nodes optimizer, not useful if cpu EP is not available.
    return;
  }

  // True when the CPU kernel registry reports a kernel matching `id`.
  const auto has_cpu_kernel = [&cpu_kernel_registry](const OpKernelRegistryId& id) {
    const KernelCreateInfo* kernel_create_info{};
    const auto status = cpu_kernel_registry->TryFindKernel(
        kCpuExecutionProvider, id.op_type_, id.domain_,
        id.version_, id.type_constraints_, &kernel_create_info);
    return status.IsOK() && kernel_create_info != nullptr;
  };

  //
  // Constructing a mapping table from operators to be transformed to their target.
  // Make sure that the new nodes we are about to create during graph transformation,
  // their kernels are available in the cpu EP.
  //
  {
    // int8 QLinearConv -> int8 NHWC QLinearConv
    const OpKernelRegistryId qconv_int8{
        "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType<int8_t>()}}}};
    if (has_cpu_kernel(qconv_int8)) {
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::INT8),
          OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true});
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kMSDomain, api::DataType::INT8),
          OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true});
    }
  }

  {
    // uint8 QLinearConv -> uint8 NHWC QLinearConv
    const OpKernelRegistryId qconv_uint8{
        "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType<uint8_t>()}}}};
    if (has_cpu_kernel(qconv_uint8)) {
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::UINT8),
          OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true});
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kMSDomain, api::DataType::UINT8),
          OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true});
    }
  }

  {
    // fp16 Conv / FusedConv -> fp16 NhwcFusedConv
    const OpKernelRegistryId nhwc_conv_fp16{
        "NhwcFusedConv", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_conv_fp16)) {
      conv_table_.emplace(
          OpIdInfo("Conv", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false});
      conv_table_.emplace(
          OpIdInfo("FusedConv", kMSDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false});
    }
  }

  {
    // fp16 MaxPool -> fp16 NHWC MaxPool
    const OpKernelRegistryId nhwc_maxpool_fp16{
        "MaxPool", kMSInternalNHWCDomain, 12, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_maxpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("MaxPool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, nhwc_maxpool_fp16.version_, false});
    }
  }

  {
    // fp16 AveragePool -> fp16 NHWC AveragePool
    const OpKernelRegistryId nhwc_avgpool_fp16{
        "AveragePool", kMSInternalNHWCDomain, 11, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_avgpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("AveragePool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, nhwc_avgpool_fp16.version_, false});
    }
  }

  {
    // fp16 GlobalAveragePool -> fp16 NHWC GlobalAveragePool
    const OpKernelRegistryId nhwc_gavgpool_fp16{
        "GlobalAveragePool", kMSInternalNHWCDomain, 1, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_gavgpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("GlobalAveragePool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, nhwc_gavgpool_fp16.version_, false});
    }
  }
}
Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
#if defined(ORT_MINIMAL_BUILD)
// update the producer/consumer info as previous optimizations may have invalidated it.
@ -36,41 +183,44 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
// Only QLinearConv needs to be handled explicitly. The rest will be transformed if needed during transpose
// optimization.
if (node->OpType() == "QLinearConv") {
auto domain = node->Domain();
// Skip if domain is incorrect
if (domain != kOnnxDomain && domain != kMSDomain) {
continue;
}
// Skip if already transformed
if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
continue;
}
// Skip if unknown rank
auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape();
if (shape == nullptr) {
continue;
}
// Convert to channels last
size_t rank = shape->dim_size();
node->SetAttributeInt("channels_last", 1);
std::vector<int64_t> input_perm = ChannelFirstToLastPerm(rank);
std::vector<int64_t> output_perm = ChannelLastToFirstPerm(rank);
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
if (domain != kMSDomain) {
SwapNodeOpTypeDomainAndSinceVersion(*api_graph, *node, "QLinearConv", kMSDomain, 1);
}
modified = true;
// Only Conv and QLinearConv needs to be handled explicitly. The rest will be
// transformed if needed during transpose optimization.
const auto* transform = NhwcConvLookup(conv_table_, *api_graph, *node);
if (nullptr == transform) {
continue;
}
// Skip if already transformed
if (transform->has_channels_last_attrib_ &&
node->GetAttributeIntDefault("channels_last", 0) == 1) {
continue;
}
// Skip if unknown rank
auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape();
if (shape == nullptr) {
continue;
}
// Convert to channels last
if (transform->has_channels_last_attrib_) {
node->SetAttributeInt("channels_last", 1);
}
size_t rank = shape->dim_size();
std::vector<int64_t> input_perm = ChannelFirstToLastPerm(rank);
std::vector<int64_t> output_perm = ChannelLastToFirstPerm(rank);
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
// Replace the operator if needed
if (node->Domain() != transform->domain_ ||
node->OpType() != transform->optype_ ||
node->SinceVersion() != transform->version_) {
SwapNodeOpTypeDomainAndSinceVersion(
*api_graph, *node, transform->optype_,
transform->domain_, transform->version_);
}
modified = true;
}
if (modified) {

View file

@ -5,7 +5,64 @@
#include "core/common/common.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
//
// Data structures internal to the nhwc transformer implementation.
// We could use the Pimpl idiom to hide all of these inside an
// implementation class, but that would add an extra pointer
// indirection at runtime.
//
namespace nhwc_map_internal {
/**
 * @brief Identifies a layout sensitive operator that is a candidate
 * for being transformed into an NHWC operator.
 *
 * The triple (operator type, domain, data type) serves as a lookup key;
 * two instances are equal only when all three fields are equal.
 */
struct OpIdInfo {
  const std::string optype_;
  const std::string domain_;
  const onnx_layout_transformation::api::DataType data_type_;

  OpIdInfo(
      const std::basic_string_view<char>& op,
      const std::basic_string_view<char>& domain,
      onnx_layout_transformation::api::DataType data_type)
      : optype_(op),
        domain_(domain),
        data_type_(data_type) {}

  bool operator==(const OpIdInfo& other) const {
    if (data_type_ != other.data_type_) {
      return false;
    }
    if (domain_ != other.domain_) {
      return false;
    }
    return optype_ == other.optype_;
  }
};
/**
 * @brief Hash functor for \ref OpIdInfo.
 *
 * Combines the string hashes of the operator type and domain with the
 * numeric value of the data type using shifts and XOR.
 */
class OpIdHash {
 public:
  size_t operator()(const OpIdInfo& op) const {
    const std::hash<std::string> string_hasher{};
    const size_t optype_hash = string_hasher(op.optype_);
    const size_t domain_hash = string_hasher(op.domain_);
    const size_t dtype_value = static_cast<size_t>(op.data_type_);
    return domain_hash ^ (optype_hash << 4) ^ (dtype_value << 16);
  }
};
/**
* @brief Information needed for operator layout transformation
*
* Describes the operator a matching node is rewritten to. Instances are
* brace-initialized in field order (optype, domain, version, flag), so the
* declaration order below must not change.
*/
struct OpTransformInfo {
// Operator type of the replacement node.
const std::string optype_;
// Domain of the replacement node.
const std::string domain_;
// Since-version of the replacement node.
const int version_;
// True when the replacement operator is switched to channels-last layout
// via a "channels_last" attribute, which the transformer then sets to 1.
const bool has_channels_last_attrib_;
};
using OpTransformMap = std::unordered_map<OpIdInfo, OpTransformInfo, OpIdHash>;
} // namespace nhwc_map_internal
namespace onnxruntime {
@ -17,14 +74,29 @@ and inserts nodes to transpose tensors as needed.
*/
class NhwcTransformer : public GraphTransformer {
private:
AllocatorPtr cpu_allocator_;
public:
explicit NhwcTransformer(AllocatorPtr cpu_allocator) noexcept
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)){};
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
/**
* @brief Usually called right after constructor, it shows whether
* this transformer should be used under current hardware configuration.
*
* @return whether this transformer would be useful under current hardware config
*/
bool IsActive() {
return !conv_table_.empty();
}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
AllocatorPtr cpu_allocator_;
/**
* A mapping table to identify operators that need to be transformed, and map
* them to the new operators that accept NHWC layout
*/
nhwc_map_internal::OpTransformMap conv_table_;
};
} // namespace onnxruntime

View file

@ -129,7 +129,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
weights_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8");
weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims());
weights_proto_u8.set_raw_data(w_temp.data<int8_t>(), w_temp.size());
weights_proto_u8.set_raw_data(w_temp.data<int8_t>(), static_cast<size_t>(w_temp.size()));
input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8);
ONNX_NAMESPACE::TensorProto weight_zp_proto_u8;
@ -140,7 +140,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8");
r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims());
r_proto_u8.set_raw_data(r_temp.data<int8_t>(), r_temp.size());
r_proto_u8.set_raw_data(r_temp.data<int8_t>(), static_cast<size_t>(r_temp.size()));
input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8);
ONNX_NAMESPACE::TensorProto r_zp_proto_u8;

View file

@ -172,6 +172,10 @@ std::optional<std::vector<int64_t>> ApiValueInfo::Shape() const {
api::DataType ApiValueInfo::DType() const {
const auto* type = node_arg_.TypeAsProto();
if (!type) {
return api::DataType::UNDEFINED;
}
if (!utils::HasTensorType(*type)) {
return api::DataType::UNDEFINED;
}

View file

@ -560,7 +560,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
static_cast<size_t>(output_count),
static_cast<size_t>(group_output_channels),
static_cast<size_t>(kernel_dim),
1, &gemm_params, thread_pool);
1, &gemm_params, nullptr);
}
}
};

View file

@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "graph_transform_test_builder.h"
#include "core/mlas/inc/mlas.h"
#include "core/graph/graph.h"
namespace onnxruntime {
@ -516,6 +516,167 @@ TEST(NhwcTransformerTests, ConvMixTensorRanks) {
TransformerLevel::Level3);
}
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
std::vector<MLFloat16> randomfp16(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
std::vector<MLFloat16> val(detail::SizeFromDims(shape));
float start = min.ToFloat();
float end = max.ToFloat();
float step = (end - start) / 128;
float value = start;
for (size_t i = 0; i < val.size(); ++i) {
value += step;
if (value > end) {
value = start;
}
val[i] = MLFloat16(value);
}
return val;
}
// fp16 specialization: creates a graph input of the given shape filled with
// the deterministic data produced by randomfp16 for the range [min, max].
template <>
NodeArg* ModelTestBuilder::MakeInput<MLFloat16>(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
  const std::vector<MLFloat16> fp16_data = randomfp16(shape, min, max);
  return MakeInput<MLFloat16>(shape, fp16_data);
}
// fp16 specialization: creates an initializer of the given shape filled with
// the deterministic data produced by randomfp16 for the range [min, max].
template <>
NodeArg* ModelTestBuilder::MakeInitializer(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
  const std::vector<MLFloat16> fp16_data = randomfp16(shape, min, max);
  return MakeInitializer(shape, fp16_data);
}
// Verifies that a single fp16 Conv node is rewritten to
// com.microsoft.NhwcFusedConv and that exactly two Transpose nodes
// remain in the optimized graph.
TEST(NhwcTransformerTests, ConvFp16) {
// Runs the check for one combination of input and weight shapes.
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
// Graph under test: fp16 input + fp16 weight initializer -> Conv -> output.
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* output_arg = builder.MakeOutput();
auto* weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
builder.AddConvNode(input_arg, weight_arg, output_arg);
};
// Expected operator counts after optimization.
auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 1);
EXPECT_EQ(op_to_count["Transpose"], 2);
};
// NOTE(review): the two levels are presumably the baseline and the target
// optimization level - confirm against TransformerTester.
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({1, 12, 37}, {32, 12, 5});
test_case({1, 23, 13, 13}, {30, 23, 3, 3});
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3});
}
// Verifies that an fp16 Conv followed by a MaxPool is rewritten to
// NhwcFusedConv + com.ms.internal.nhwc.MaxPool and that exactly two
// Transpose nodes remain in the optimized graph.
TEST(NhwcTransformerTests, ConvMaxPoolFp16) {
// Runs the check for one combination of input and weight shapes.
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
// Graph under test: input -> Conv -> MaxPool -> output.
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* conv_output_arg = builder.MakeIntermediate();
auto* output_arg = builder.MakeOutput();
auto* conv_weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
builder.AddConvNode(input_arg, conv_weight_arg, conv_output_arg);
Node& pool_node = builder.AddNode("MaxPool", {conv_output_arg}, {output_arg});
// weights_shape.size() - 2 is the number of spatial dimensions:
// pads holds two entries of 1 per spatial dimension.
std::vector<int64_t> pads((weights_shape.size() - 2) * 2, 1);
pool_node.AddAttribute("pads", pads);
// Pooling window of 3 along every spatial dimension.
std::vector<int64_t> kernel_shape(weights_shape.size() - 2, 3);
pool_node.AddAttribute("kernel_shape", kernel_shape);
};
// Expected operator counts after optimization.
auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 1);
EXPECT_EQ(op_to_count["com.ms.internal.nhwc.MaxPool"], 1);
EXPECT_EQ(op_to_count["Transpose"], 2);
};
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
};
// Test a convolution followed by a max pool for 1D/2D/3D inputs.
test_case({5, 12, 37}, {128, 12, 5});
test_case({3, 14, 13, 13}, {64, 14, 3, 3});
test_case({1, 15, 11, 13, 15}, {31, 15, 5, 3, 3});
}
// Verifies that a chain Conv -> GlobalAveragePool -> Conv -> GlobalAveragePool
// on fp16 data is rewritten to the NHWC variants of all four nodes and that
// exactly two Transpose nodes remain in the optimized graph.
TEST(NhwcTransformerTests, ConvGlobalAveragePoolFp16) {
// Graph under test; the builder calls below define the node order.
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* conv1_output_arg = builder.MakeIntermediate();
auto* conv2_output_arg = builder.MakeIntermediate();
auto* gavgpool1_output_arg = builder.MakeIntermediate();
auto* output_arg = builder.MakeOutput();
auto* conv1_weight_arg = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* conv2_weight_arg = builder.MakeInitializer<MLFloat16>({16, 30, 1, 1}, MLFloat16(-1.5f), MLFloat16(1.5f));
// First convolution: 3x3 kernel with a padding of 1 on every side.
Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg);
conv1_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
builder.AddNode("GlobalAveragePool", {conv1_output_arg}, {gavgpool1_output_arg});
// Second convolution: 1x1 kernel applied to the pooled result.
builder.AddConvNode(gavgpool1_output_arg, conv2_weight_arg, conv2_output_arg);
builder.AddNode("GlobalAveragePool", {conv2_output_arg}, {output_arg});
};
// Expected operator counts after optimization.
auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 2);
EXPECT_EQ(op_to_count["com.ms.internal.nhwc.GlobalAveragePool"], 2);
EXPECT_EQ(op_to_count["Transpose"], 2);
};
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
}
// Verifies that a chain Conv -> AveragePool -> Conv -> AveragePool on fp16
// data is rewritten to the NHWC variants of all four nodes and that exactly
// two Transpose nodes remain in the optimized graph.
TEST(NhwcTransformerTests, ConvAveragePoolFp16) {
// Graph under test; the builder calls below define the node order.
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* conv1_output_arg = builder.MakeIntermediate();
auto* conv2_output_arg = builder.MakeIntermediate();
auto* avgpool1_output_arg = builder.MakeIntermediate();
auto* output_arg = builder.MakeOutput();
auto* conv1_weight_arg = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
auto* conv2_weight_arg = builder.MakeInitializer<MLFloat16>({16, 30, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
// First convolution: 3x3 kernel with a padding of 1 on every side.
Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg);
conv1_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
// First average pool: 3x3 window with a padding of 1 on every side.
Node& avgpool_node1 = builder.AddNode(
"AveragePool", {conv1_output_arg}, {avgpool1_output_arg});
avgpool_node1.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
avgpool_node1.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
// Second convolution (no explicit pads) followed by the second average pool.
builder.AddConvNode(avgpool1_output_arg, conv2_weight_arg, conv2_output_arg);
Node& avgpool_node2 = builder.AddNode(
"AveragePool", {conv2_output_arg}, {output_arg});
avgpool_node2.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
avgpool_node2.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
};
// Expected operator counts after optimization.
auto check_nhwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 2);
EXPECT_EQ(op_to_count["com.ms.internal.nhwc.AveragePool"], 2);
EXPECT_EQ(op_to_count["Transpose"], 2);
};
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
}
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED
#endif // DISABLE_CONTRIB_OPS
} // namespace test