mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
NHWC graph optimizer (#15724)
### Description Augment the NHWC graph optimizer to accommodate fp16 operators. ### Motivation and Context A new fp16 conv operator has been added, and it prefers the NHWC data layout. We need to augment the existing graph optimizers to better utilize the new operator.
This commit is contained in:
parent
d35850c142
commit
0e9472d391
8 changed files with 464 additions and 46 deletions
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/graph_transformer_level.h"
|
||||
|
||||
|
|
@ -69,4 +70,27 @@ class GraphTransformer {
|
|||
const std::string name_;
|
||||
const InlinedHashSet<std::string_view> compatible_provider_types_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Immutable object to identify a kernel registration.
|
||||
*
|
||||
* This data structure is used by the graph transformers to check whether
|
||||
* a kernel is registered with the execution provider (i.e. has an
|
||||
* implementation). If not, the transformer can not generate a node with
|
||||
* such kernel.
|
||||
*/
|
||||
struct OpKernelRegistryId {
|
||||
const std::string op_type_;
|
||||
const std::string domain_;
|
||||
const int version_;
|
||||
const InlinedHashMap<std::string, MLDataType> type_constraints_;
|
||||
|
||||
OpKernelRegistryId(
|
||||
const std::basic_string_view<char>& op,
|
||||
const std::basic_string_view<char>& domain,
|
||||
const int version,
|
||||
std::initializer_list<std::pair<std::string, MLDataType>> init_list)
|
||||
: op_type_(op), domain_(domain), version_(version), type_constraints_(init_list) {}
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -350,7 +350,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<NchwcTransformer>());
|
||||
}
|
||||
auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault);
|
||||
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
// NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things
|
||||
// of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for
|
||||
// x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So
|
||||
|
|
@ -421,9 +425,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
|
|||
// currently the only level 3 optimizer is the NhwcTransformer which is fully supported at runtime
|
||||
if (!saving) {
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
|
||||
auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault);
|
||||
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
|
||||
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
|
||||
auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
|
||||
if (nhwc_transformer->IsActive()) {
|
||||
transformers.emplace_back(std::move(nhwc_transformer));
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(cpu_execution_provider);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -2,18 +2,165 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <deque>
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/nhwc_transformer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace onnx_layout_transformation;
|
||||
using namespace nhwc_map_internal;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
 * @brief Finds the NHWC transform entry that applies to a node, if any.
 *
 * The lookup key is built from the node's op type, its domain and the
 * data type reported for the node's first input.
 *
 * @param conv_table Table mapping operator identities to their transform info.
 * @param graph      Graph used to query value info of the node's first input.
 * @param node       Node to look up.
 * @return Pointer to the matching table entry, or nullptr when the node has
 *         no inputs or no entry matches.
 */
static inline const OpTransformInfo* NhwcConvLookup(
    const OpTransformMap& conv_table,
    const api::GraphRef& graph,
    api::NodeRef& node) {
  const auto& node_optype = node.OpType();
  const auto& node_domain = node.Domain();
  const auto node_inputs = node.Inputs();

  // A node without inputs can never be a transformation candidate.
  if (node_inputs.empty()) {
    return nullptr;
  }

  const auto first_input_info = graph.GetValueInfo(node_inputs[0]);
  const api::DataType first_input_dtype = first_input_info->DType();
  OpIdInfo lookup_key{node_optype, node_domain, first_input_dtype};

  const auto match = conv_table.find(lookup_key);
  return match != conv_table.end() ? &(match->second) : nullptr;
}
|
||||
|
||||
/**
 * @brief Builds the table of operators that can be rewritten to NHWC variants.
 *
 * For every candidate target operator the CPU kernel registry is queried
 * first; source operators are added to conv_table_ only when the kernel of
 * the target operator was found. When no registry is supplied the table
 * stays empty.
 *
 * @param cpu_allocator       Allocator stored in cpu_allocator_.
 * @param cpu_kernel_registry Kernel registry of the CPU execution provider; may be null.
 */
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
    : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
  if (!cpu_kernel_registry) {
    // This is a CPU op nodes optimizer, not useful if cpu EP is not available.
    return;
  }

  //
  // Constructing a mapping table from operators to be transformed to their target.
  // Make sure that the new nodes we are about to create during graph transformation,
  // their kernels are available in the cpu EP.
  //

  // True when the CPU kernel registry reports a kernel for the given registration id.
  const auto has_cpu_kernel = [&cpu_kernel_registry](const OpKernelRegistryId& id) {
    const KernelCreateInfo* kernel_create_info{};
    const auto status = cpu_kernel_registry->TryFindKernel(
        kCpuExecutionProvider, id.op_type_, id.domain_,
        id.version_, id.type_constraints_, &kernel_create_info);
    return status.IsOK() && kernel_create_info != nullptr;
  };

  {
    // int8 qconv -> int8 nhwc qconv
    OpKernelRegistryId qconv_int8{
        "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType<int8_t>()}}}};
    if (has_cpu_kernel(qconv_int8)) {
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::INT8),
          OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true});
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kMSDomain, api::DataType::INT8),
          OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true});
    }
  }

  {
    // uint8 qconv -> uint8 nhwc qconv
    OpKernelRegistryId qconv_uint8{
        "QLinearConv", kMSDomain, 1, {{"T1", {DataTypeImpl::GetTensorType<uint8_t>()}}}};
    if (has_cpu_kernel(qconv_uint8)) {
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::UINT8),
          OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true});
      conv_table_.emplace(
          OpIdInfo("QLinearConv", kMSDomain, api::DataType::UINT8),
          OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true});
    }
  }

  {
    // fp16 conv -> fp16 nhwc conv
    OpKernelRegistryId nhwc_conv_fp16{
        "NhwcFusedConv", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_conv_fp16)) {
      conv_table_.emplace(
          OpIdInfo("Conv", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false});
      conv_table_.emplace(
          OpIdInfo("FusedConv", kMSDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false});
    }
  }

  {
    // fp16 MaxPool -> fp16 nhwc MaxPool
    OpKernelRegistryId nhwc_maxpool_fp16{
        "MaxPool", kMSInternalNHWCDomain, 12, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_maxpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("MaxPool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, nhwc_maxpool_fp16.version_, false});
    }
  }

  {
    // fp16 AveragePool -> fp16 nhwc AveragePool
    OpKernelRegistryId nhwc_avgpool_fp16{
        "AveragePool", kMSInternalNHWCDomain, 11, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_avgpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("AveragePool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, nhwc_avgpool_fp16.version_, false});
    }
  }

  {
    // fp16 GlobalAveragePool -> fp16 nhwc GlobalAveragePool
    OpKernelRegistryId nhwc_gavgpool_fp16{
        "GlobalAveragePool", kMSInternalNHWCDomain, 1, {{"T", {DataTypeImpl::GetTensorType<MLFloat16>()}}}};
    if (has_cpu_kernel(nhwc_gavgpool_fp16)) {
      conv_table_.emplace(
          OpIdInfo("GlobalAveragePool", kOnnxDomain, api::DataType::FLOAT16),
          OpTransformInfo{nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, nhwc_gavgpool_fp16.version_, false});
    }
  }
}
|
||||
|
||||
Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
#if defined(ORT_MINIMAL_BUILD)
|
||||
// update the producer/consumer info as previous optimizations may have invalidated it.
|
||||
|
|
@ -36,41 +183,44 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
|
||||
// Only QLinearConv needs to be handled explicitly. The rest will be transformed if needed during transpose
|
||||
// optimization.
|
||||
if (node->OpType() == "QLinearConv") {
|
||||
auto domain = node->Domain();
|
||||
|
||||
// Skip if domain is incorrect
|
||||
if (domain != kOnnxDomain && domain != kMSDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if already transformed
|
||||
if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if unknown rank
|
||||
auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape();
|
||||
if (shape == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Convert to channels last
|
||||
size_t rank = shape->dim_size();
|
||||
node->SetAttributeInt("channels_last", 1);
|
||||
|
||||
std::vector<int64_t> input_perm = ChannelFirstToLastPerm(rank);
|
||||
std::vector<int64_t> output_perm = ChannelLastToFirstPerm(rank);
|
||||
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
|
||||
|
||||
if (domain != kMSDomain) {
|
||||
SwapNodeOpTypeDomainAndSinceVersion(*api_graph, *node, "QLinearConv", kMSDomain, 1);
|
||||
}
|
||||
|
||||
modified = true;
|
||||
// Only Conv and QLinearConv needs to be handled explicitly. The rest will be
|
||||
// transformed if needed during transpose optimization.
|
||||
const auto* transform = NhwcConvLookup(conv_table_, *api_graph, *node);
|
||||
if (nullptr == transform) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if already transformed
|
||||
if (transform->has_channels_last_attrib_ &&
|
||||
node->GetAttributeIntDefault("channels_last", 0) == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if unknown rank
|
||||
auto shape = NodeFromApiNode(*node).InputDefs()[0]->Shape();
|
||||
if (shape == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Convert to channels last
|
||||
if (transform->has_channels_last_attrib_) {
|
||||
node->SetAttributeInt("channels_last", 1);
|
||||
}
|
||||
size_t rank = shape->dim_size();
|
||||
std::vector<int64_t> input_perm = ChannelFirstToLastPerm(rank);
|
||||
std::vector<int64_t> output_perm = ChannelLastToFirstPerm(rank);
|
||||
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
|
||||
|
||||
// Replace the operator if needed
|
||||
if (node->Domain() != transform->domain_ ||
|
||||
node->OpType() != transform->optype_ ||
|
||||
node->SinceVersion() != transform->version_) {
|
||||
SwapNodeOpTypeDomainAndSinceVersion(
|
||||
*api_graph, *node, transform->optype_,
|
||||
transform->domain_, transform->version_);
|
||||
}
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,64 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_utils.h"
|
||||
|
||||
//
|
||||
// Data structures internal to nhwc transformer implementation.
|
||||
// Maybe we should use Pimpl Idiom to hide all these into
|
||||
// an implementation class. But it would add an extra pointer
|
||||
// chasing during runtime.
|
||||
//
|
||||
namespace nhwc_map_internal {
|
||||
|
||||
/**
|
||||
* @brief For identifying layout sensive operators
|
||||
* as candidates for transforming to NHWC ops.
|
||||
*/
|
||||
struct OpIdInfo {
|
||||
const std::string optype_;
|
||||
const std::string domain_;
|
||||
const onnx_layout_transformation::api::DataType data_type_;
|
||||
|
||||
OpIdInfo(
|
||||
const std::basic_string_view<char>& op,
|
||||
const std::basic_string_view<char>& domain,
|
||||
onnx_layout_transformation::api::DataType data_type)
|
||||
: optype_(op), domain_(domain), data_type_(data_type) {}
|
||||
|
||||
bool operator==(const OpIdInfo& other) const {
|
||||
return optype_ == other.optype_ && domain_ == other.domain_ && data_type_ == other.data_type_;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Hash function for \ref OpIdInfo
|
||||
*/
|
||||
class OpIdHash {
|
||||
public:
|
||||
size_t operator()(const OpIdInfo& op) const {
|
||||
size_t h1 = std::hash<std::string>{}(op.optype_);
|
||||
size_t h2 = std::hash<std::string>{}(op.domain_);
|
||||
size_t h3 = size_t(op.data_type_);
|
||||
return h2 ^ (h1 << 4) ^ (h3 << 16);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Information needed for operator layout transformation
|
||||
*/
|
||||
struct OpTransformInfo {
|
||||
const std::string optype_;
|
||||
const std::string domain_;
|
||||
const int version_;
|
||||
const bool has_channels_last_attrib_;
|
||||
};
|
||||
|
||||
using OpTransformMap = std::unordered_map<OpIdInfo, OpTransformInfo, OpIdHash>;
|
||||
|
||||
} // namespace nhwc_map_internal
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -17,14 +74,29 @@ and inserts nodes to transpose tensors as needed.
|
|||
*/
|
||||
class NhwcTransformer : public GraphTransformer {
|
||||
private:
|
||||
AllocatorPtr cpu_allocator_;
|
||||
|
||||
public:
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator) noexcept
|
||||
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)){};
|
||||
explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Usually called right after constructor, it shows whether
|
||||
* this transformer should be used under current hardware configuration.
|
||||
*
|
||||
* @return whether this transformer would be useful under current hardware config
|
||||
*/
|
||||
bool IsActive() {
|
||||
return !conv_table_.empty();
|
||||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
AllocatorPtr cpu_allocator_;
|
||||
|
||||
/**
|
||||
* A mapping table to identify operators that need to be transformed, and map
|
||||
* them to the new operators that accept NHWC layout
|
||||
*/
|
||||
nhwc_map_internal::OpTransformMap conv_table_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
weights_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
|
||||
weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8");
|
||||
weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims());
|
||||
weights_proto_u8.set_raw_data(w_temp.data<int8_t>(), w_temp.size());
|
||||
weights_proto_u8.set_raw_data(w_temp.data<int8_t>(), static_cast<size_t>(w_temp.size()));
|
||||
input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto weight_zp_proto_u8;
|
||||
|
|
@ -140,7 +140,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
|
|||
r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
|
||||
r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8");
|
||||
r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims());
|
||||
r_proto_u8.set_raw_data(r_temp.data<int8_t>(), r_temp.size());
|
||||
r_proto_u8.set_raw_data(r_temp.data<int8_t>(), static_cast<size_t>(r_temp.size()));
|
||||
input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto r_zp_proto_u8;
|
||||
|
|
|
|||
|
|
@ -172,6 +172,10 @@ std::optional<std::vector<int64_t>> ApiValueInfo::Shape() const {
|
|||
|
||||
api::DataType ApiValueInfo::DType() const {
|
||||
const auto* type = node_arg_.TypeAsProto();
|
||||
if (!type) {
|
||||
return api::DataType::UNDEFINED;
|
||||
}
|
||||
|
||||
if (!utils::HasTensorType(*type)) {
|
||||
return api::DataType::UNDEFINED;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -560,7 +560,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
|
|||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(group_output_channels),
|
||||
static_cast<size_t>(kernel_dim),
|
||||
1, &gemm_params, thread_pool);
|
||||
1, &gemm_params, nullptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "graph_transform_test_builder.h"
|
||||
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -516,6 +516,167 @@ TEST(NhwcTransformerTests, ConvMixTensorRanks) {
|
|||
TransformerLevel::Level3);
|
||||
}
|
||||
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
|
||||
std::vector<MLFloat16> randomfp16(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
|
||||
std::vector<MLFloat16> val(detail::SizeFromDims(shape));
|
||||
float start = min.ToFloat();
|
||||
float end = max.ToFloat();
|
||||
float step = (end - start) / 128;
|
||||
float value = start;
|
||||
for (size_t i = 0; i < val.size(); ++i) {
|
||||
value += step;
|
||||
if (value > end) {
|
||||
value = start;
|
||||
}
|
||||
val[i] = MLFloat16(value);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
NodeArg* ModelTestBuilder::MakeInput<MLFloat16>(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
|
||||
return MakeInput<MLFloat16>(shape, randomfp16(shape, min, max));
|
||||
}
|
||||
|
||||
template <>
|
||||
NodeArg* ModelTestBuilder::MakeInitializer(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
|
||||
return MakeInitializer(shape, randomfp16(shape, min, max));
|
||||
}
|
||||
|
||||
// A single fp16 Conv is expected to become one NhwcFusedConv wrapped by two Transposes.
TEST(NhwcTransformerTests, ConvFp16) {
  auto run_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
    auto build_model = [&](ModelTestBuilder& builder) {
      auto* x = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
      auto* y = builder.MakeOutput();
      auto* w = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));

      builder.AddConvNode(x, w, y);
    };

    auto verify_graph = [&](InferenceSessionWrapper& session) {
      auto counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(counts["com.microsoft.NhwcFusedConv"], 1);
      EXPECT_EQ(counts["Transpose"], 2);
    };

    TransformerTester(build_model,
                      verify_graph,
                      TransformerLevel::Level2,
                      TransformerLevel::Level3);
  };

  // Basic single-convolution cases: 1D, 2D and 3D.
  run_case({1, 12, 37}, {32, 12, 5});
  run_case({1, 23, 13, 13}, {30, 23, 3, 3});
  run_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3});
}
|
||||
|
||||
// fp16 Conv followed by MaxPool: both are expected to be converted to their NHWC
// counterparts, with only two Transposes left in the graph.
TEST(NhwcTransformerTests, ConvMaxPoolFp16) {
  auto run_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
    auto build_model = [&](ModelTestBuilder& builder) {
      auto* x = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
      auto* conv_out = builder.MakeIntermediate();
      auto* y = builder.MakeOutput();
      auto* conv_w = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));

      builder.AddConvNode(x, conv_w, conv_out);
      Node& max_pool = builder.AddNode("MaxPool", {conv_out}, {y});

      // The weight shape has two non-spatial dims; the rest are spatial.
      const size_t spatial_rank = weights_shape.size() - 2;
      std::vector<int64_t> pool_pads(spatial_rank * 2, 1);
      max_pool.AddAttribute("pads", pool_pads);
      std::vector<int64_t> pool_kernel(spatial_rank, 3);
      max_pool.AddAttribute("kernel_shape", pool_kernel);
    };

    auto verify_graph = [&](InferenceSessionWrapper& session) {
      auto counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(counts["com.microsoft.NhwcFusedConv"], 1);
      EXPECT_EQ(counts["com.ms.internal.nhwc.MaxPool"], 1);
      EXPECT_EQ(counts["Transpose"], 2);
    };

    TransformerTester(build_model,
                      verify_graph,
                      TransformerLevel::Level2,
                      TransformerLevel::Level3);
  };

  // Conv + MaxPool with 1D, 2D and 3D inputs.
  run_case({5, 12, 37}, {128, 12, 5});
  run_case({3, 14, 13, 13}, {64, 14, 3, 3});
  run_case({1, 15, 11, 13, 15}, {31, 15, 5, 3, 3});
}
|
||||
|
||||
// Conv -> GlobalAveragePool -> Conv -> GlobalAveragePool in fp16: every op is expected
// to be converted to its NHWC counterpart, with only two Transposes left in the graph.
TEST(NhwcTransformerTests, ConvGlobalAveragePoolFp16) {
  auto build_model = [&](ModelTestBuilder& builder) {
    auto* x = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
    auto* first_conv_out = builder.MakeIntermediate();
    auto* second_conv_out = builder.MakeIntermediate();
    auto* first_pool_out = builder.MakeIntermediate();
    auto* y = builder.MakeOutput();
    auto* first_conv_w = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
    auto* second_conv_w = builder.MakeInitializer<MLFloat16>({16, 30, 1, 1}, MLFloat16(-1.5f), MLFloat16(1.5f));

    Node& first_conv = builder.AddConvNode(x, first_conv_w, first_conv_out);
    first_conv.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});

    builder.AddNode("GlobalAveragePool", {first_conv_out}, {first_pool_out});
    builder.AddConvNode(first_pool_out, second_conv_w, second_conv_out);
    builder.AddNode("GlobalAveragePool", {second_conv_out}, {y});
  };

  auto verify_graph = [&](InferenceSessionWrapper& session) {
    auto counts = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(counts["com.microsoft.NhwcFusedConv"], 2);
    EXPECT_EQ(counts["com.ms.internal.nhwc.GlobalAveragePool"], 2);
    EXPECT_EQ(counts["Transpose"], 2);
  };

  TransformerTester(build_model,
                    verify_graph,
                    TransformerLevel::Level2,
                    TransformerLevel::Level3);
}
|
||||
|
||||
// Conv -> AveragePool -> Conv -> AveragePool in fp16: every op is expected to be
// converted to its NHWC counterpart, with only two Transposes left in the graph.
TEST(NhwcTransformerTests, ConvAveragePoolFp16) {
  auto build_model = [&](ModelTestBuilder& builder) {
    auto* x = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
    auto* first_conv_out = builder.MakeIntermediate();
    auto* second_conv_out = builder.MakeIntermediate();
    auto* first_pool_out = builder.MakeIntermediate();
    auto* y = builder.MakeOutput();
    auto* first_conv_w = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
    auto* second_conv_w = builder.MakeInitializer<MLFloat16>({16, 30, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));

    // First stage: padded Conv followed by a padded 3x3 AveragePool.
    Node& first_conv = builder.AddConvNode(x, first_conv_w, first_conv_out);
    first_conv.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
    Node& first_pool = builder.AddNode("AveragePool", {first_conv_out}, {first_pool_out});
    first_pool.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
    first_pool.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});

    // Second stage: Conv without explicit pads followed by a padded 3x3 AveragePool.
    builder.AddConvNode(first_pool_out, second_conv_w, second_conv_out);
    Node& second_pool = builder.AddNode("AveragePool", {second_conv_out}, {y});
    second_pool.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
    second_pool.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
  };

  auto verify_graph = [&](InferenceSessionWrapper& session) {
    auto counts = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(counts["com.microsoft.NhwcFusedConv"], 2);
    EXPECT_EQ(counts["com.ms.internal.nhwc.AveragePool"], 2);
    EXPECT_EQ(counts["Transpose"], 2);
  };

  TransformerTester(build_model,
                    verify_graph,
                    TransformerLevel::Level2,
                    TransformerLevel::Level3);
}
|
||||
|
||||
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
|
||||
#endif // DISABLE_CONTRIB_OPS
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue