mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Make session configuration options available to kernels via OpKernelInfo (#18897)
### Description
<!-- Describe your changes. -->
Pass through the ConfigOptions from the session via OpKernelInfo so that
kernel behavior can be configured.
Initial usage would be to optionally enable a fast path for ARM64 bfloat16 GEMM - see #17031
Other usages could be things like selecting the exact implementations of the activation functions for RNN operators instead of the default approximations (e.g. use [sigmoid_exact instead of sigmoid](2d6e2e243d/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h (L379-L382)))
OpKernelInfo is already passing through things from the session state, and adding a new member of ConfigOptions
is the simpler update. It's also a more natural fit given it's providing state/info to the kernel.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
a503561d0c
commit
8f2e57f5d0
26 changed files with 246 additions and 162 deletions
|
|
@ -28,7 +28,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
|
||||
const OrtValueNameIdxMap& mlvalue_name_idx_map,
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const AllocatorMap& allocators = {});
|
||||
const AllocatorMap& allocators,
|
||||
const ConfigOptions& config_options);
|
||||
|
||||
OpKernelInfo(const OpKernelInfo& other);
|
||||
|
||||
|
|
@ -50,6 +51,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
|
||||
const AllocatorMap& GetAllocators() const { return allocators_; }
|
||||
|
||||
const ConfigOptions& GetConfigOptions() const { return config_options_; }
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_MOVE(OpKernelInfo);
|
||||
ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
|
||||
|
|
@ -64,6 +67,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
const DataTransferManager& data_transfer_mgr_;
|
||||
ProtoHelperNodeContext proto_helper_context_;
|
||||
const AllocatorMap& allocators_;
|
||||
const ConfigOptions& config_options_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ Status KernelRegistryManager::CreateKernel(const Node& node,
|
|||
session_state.GetConstantInitializedTensors(),
|
||||
session_state.GetOrtValueNameIdxMap(),
|
||||
session_state.GetDataTransferMgr(),
|
||||
session_state.GetAllocators());
|
||||
session_state.GetAllocators(),
|
||||
session_state.GetSessionOptions().config_options);
|
||||
|
||||
return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
|
|||
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const AllocatorMap& allocators)
|
||||
const AllocatorMap& allocators,
|
||||
const ConfigOptions& config_options)
|
||||
: OpNodeProtoHelper(&proto_helper_context_),
|
||||
node_(node),
|
||||
kernel_def_(kernel_def),
|
||||
|
|
@ -24,15 +25,22 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
|
|||
ort_value_name_idx_map_(ort_value_name_idx_map),
|
||||
data_transfer_mgr_(data_transfer_mgr),
|
||||
proto_helper_context_(node),
|
||||
allocators_(allocators) {}
|
||||
allocators_(allocators),
|
||||
config_options_(config_options) {
|
||||
}
|
||||
|
||||
OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
|
||||
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
|
||||
other.ort_value_name_idx_map_, other.data_transfer_mgr_, other.allocators_) {}
|
||||
other.ort_value_name_idx_map_, other.data_transfer_mgr_,
|
||||
other.allocators_, other.config_options_) {
|
||||
}
|
||||
|
||||
AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const {
|
||||
auto it = allocators_.find(execution_provider_->GetOrtDeviceByMemType(mem_type));
|
||||
if (it != allocators_.end()) return it->second;
|
||||
if (it != allocators_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,12 @@ namespace onnxruntime {
|
|||
|
||||
ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
|
||||
bool skip_dequantize_linear,
|
||||
const ConfigOptions& config_options,
|
||||
const InlinedHashSet<std::string_view>& compatible_execution_providers,
|
||||
const InlinedHashSet<std::string>& excluded_initializers) noexcept
|
||||
: GraphTransformer("ConstantFolding", compatible_execution_providers),
|
||||
skip_dequantize_linear_(skip_dequantize_linear),
|
||||
config_options_(config_options),
|
||||
excluded_initializers_(excluded_initializers),
|
||||
execution_provider_(execution_provider) {
|
||||
}
|
||||
|
|
@ -250,12 +252,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// override the EP assigned to the node so that it will use the CPU kernel for Compute.
|
||||
node->SetExecutionProviderType(kCpuExecutionProvider);
|
||||
|
||||
kernel = info.CreateKernel(node);
|
||||
kernel = info.CreateKernel(node, config_options_);
|
||||
|
||||
// undo the EP change to the value that was assigned at graph partitioning time
|
||||
node->SetExecutionProviderType(ep_type);
|
||||
} else {
|
||||
kernel = info.CreateKernel(node);
|
||||
kernel = info.CreateKernel(node, config_options_);
|
||||
}
|
||||
|
||||
// We currently constant fold using the CPU EP only.
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class ConstantFolding : public GraphTransformer {
|
|||
*/
|
||||
ConstantFolding(const IExecutionProvider& execution_provider,
|
||||
bool skip_dequantize_linear,
|
||||
const ConfigOptions& config_options,
|
||||
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
|
||||
|
||||
|
|
@ -31,6 +32,7 @@ class ConstantFolding : public GraphTransformer {
|
|||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool skip_dequantize_linear_;
|
||||
const ConfigOptions& config_options_;
|
||||
const InlinedHashSet<std::string> excluded_initializers_;
|
||||
const IExecutionProvider& execution_provider_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -223,7 +223,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
|
||||
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
|
||||
session_options.config_options));
|
||||
transformers.emplace_back(std::make_unique<MatMulAddFusion>());
|
||||
transformers.emplace_back(std::make_unique<ReshapeFusion>());
|
||||
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
|
||||
|
|
|
|||
|
|
@ -128,26 +128,34 @@ static Status TryCreateKernel(const Node& node,
|
|||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
FuncManager& funcs_mgr,
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const ConfigOptions& config_options,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
|
||||
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
|
||||
&kernel_create_info));
|
||||
|
||||
static const AllocatorMap dummy_allocators;
|
||||
|
||||
OpKernelInfo kernel_info(node,
|
||||
*kernel_create_info->kernel_def,
|
||||
execution_provider,
|
||||
constant_initialized_tensors,
|
||||
ort_value_name_idx_map,
|
||||
data_transfer_mgr);
|
||||
data_transfer_mgr,
|
||||
dummy_allocators,
|
||||
config_options);
|
||||
|
||||
return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel);
|
||||
}
|
||||
|
||||
std::unique_ptr<const OpKernel> OptimizerExecutionFrame::Info::CreateKernel(const Node* node) const {
|
||||
std::unique_ptr<const OpKernel>
|
||||
OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOptions& config_options) const {
|
||||
std::unique_ptr<OpKernel> op_kernel;
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
|
||||
FuncManager func;
|
||||
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_,
|
||||
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
|
||||
op_kernel);
|
||||
|
||||
// Kernel found in the CPU kernel registry
|
||||
|
|
|
|||
|
|
@ -27,11 +27,13 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
const Path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
|
||||
Info(const std::vector<const Node*>& nodes,
|
||||
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
|
||||
const Path& model_path,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
|
||||
|
||||
~Info() = default;
|
||||
|
||||
const AllocatorPtr& GetAllocator() const {
|
||||
|
|
@ -52,7 +54,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
|
|||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<const OpKernel> CreateKernel(const Node* node) const;
|
||||
std::unique_ptr<const OpKernel> CreateKernel(const Node* node, const ConfigOptions& config_options) const;
|
||||
|
||||
// Check if an kernel create info can be found in the registry.
|
||||
Status TryFindKernel(const Node* node, const KernelCreateInfo** out) const;
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ struct Logger;
|
|||
struct Capture;
|
||||
} // namespace logging
|
||||
struct ComputeCapability;
|
||||
struct ConfigOptions;
|
||||
struct DataTransferManager;
|
||||
struct IndexedSubGraph;
|
||||
struct IndexedSubGraph_MetaDef;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <optional>
|
||||
|
||||
// Public wrappers around internal ort interfaces (currently)
|
||||
#include "core/providers/shared_library/provider_host_api.h"
|
||||
|
||||
|
|
@ -426,6 +428,9 @@ struct ProviderHost {
|
|||
|
||||
virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;
|
||||
|
||||
// ConfigOptions
|
||||
virtual std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0;
|
||||
|
||||
// ComputeCapability
|
||||
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
|
||||
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
|
||||
|
|
@ -808,6 +813,7 @@ struct ProviderHost {
|
|||
virtual uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) = 0;
|
||||
virtual uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) = 0;
|
||||
virtual const Node& OpKernelInfo__node(const OpKernelInfo* p) = 0;
|
||||
virtual const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) = 0;
|
||||
|
||||
// SessionState
|
||||
virtual const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) = 0;
|
||||
|
|
|
|||
|
|
@ -335,6 +335,14 @@ struct DataTypeUtils final {
|
|||
|
||||
} // namespace Utils
|
||||
|
||||
struct ConfigOptions final {
|
||||
std::optional<std::string> GetConfigEntry(const std::string& config_key) const {
|
||||
return g_host->ConfigOptions__GetConfigEntry(this, config_key);
|
||||
}
|
||||
|
||||
PROVIDER_DISALLOW_ALL(ConfigOptions)
|
||||
};
|
||||
|
||||
struct ComputeCapability final {
|
||||
static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
|
||||
static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
|
||||
|
|
@ -901,6 +909,8 @@ struct OpKernelInfo final {
|
|||
|
||||
const Node& node() const noexcept { return g_host->OpKernelInfo__node(this); }
|
||||
|
||||
const ConfigOptions& GetConfigOptions() const { return g_host->OpKernelInfo__GetConfigOptions(this); }
|
||||
|
||||
OpKernelInfo() = delete;
|
||||
OpKernelInfo(const OpKernelInfo&) = delete;
|
||||
void operator=(const OpKernelInfo&) = delete;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/framework/allocator_utils.h"
|
||||
#include "core/framework/config_options.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
|
|
@ -529,6 +530,11 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }
|
||||
|
||||
// ConfigOptions (wrapped)
|
||||
std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override {
|
||||
return p->GetConfigEntry(config_key);
|
||||
}
|
||||
|
||||
// ComputeCapability (wrapped)
|
||||
std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
|
||||
void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
|
||||
|
|
@ -934,6 +940,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) override { return p->GetInputCount(); }
|
||||
uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) override { return p->GetOutputCount(); }
|
||||
const Node& OpKernelInfo__node(const OpKernelInfo* p) override { return p->node(); }
|
||||
const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) override { return p->GetConfigOptions(); }
|
||||
|
||||
// SessionState (wrapped)
|
||||
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }
|
||||
|
|
|
|||
|
|
@ -421,7 +421,10 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info,
|
|||
static const OrtValueNameIdxMap kEmptyNameMap;
|
||||
|
||||
OpKernelInfo tmp_kernel_info(*node_ptr.get(), *kernel_def, *ep, kEmptyValueMap, kEmptyNameMap,
|
||||
kernel_info->GetDataTransferManager(), kernel_info->GetAllocators());
|
||||
kernel_info->GetDataTransferManager(),
|
||||
kernel_info->GetAllocators(),
|
||||
kernel_info->GetConfigOptions());
|
||||
|
||||
std::unique_ptr<onnxruntime::OpKernel> op_kernel;
|
||||
|
||||
auto& node_repo = NodeRepo::GetInstance();
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ class PlannerTest : public ::testing::Test {
|
|||
ASSERT_NE(ep, nullptr);
|
||||
auto info = std::make_unique<OpKernelInfo>(
|
||||
*p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(),
|
||||
state_->GetDataTransferMgr());
|
||||
state_->GetDataTransferMgr(), state_->GetAllocators(), state_->GetSessionOptions().config_options);
|
||||
|
||||
op_kernel_infos_.push_back(std::move(info));
|
||||
const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{};
|
||||
|
|
|
|||
|
|
@ -82,6 +82,11 @@ ProviderInfo_ROCM& GetProviderInfo_ROCM();
|
|||
class FuseAdd : public OpKernel {
|
||||
public:
|
||||
explicit FuseAdd(const OpKernelInfo& info) : OpKernel(info) {
|
||||
// logic for testing that a session options config value can be read here
|
||||
auto test_throw_in_ctor = info.GetConfigOptions().GetConfigEntry("ThrowInKernelCtor");
|
||||
if (test_throw_in_ctor == "1") {
|
||||
ORT_THROW("Test exception in ctor");
|
||||
};
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
|
|
@ -96,6 +101,7 @@ class FuseAdd : public OpKernel {
|
|||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
constexpr const char* kFuseTest = "FuseTest";
|
||||
constexpr const char* kFuseExecutionProvider = "FuseExecutionProvider";
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kFuseExecutionProvider, kFuseTest, 1, FuseAdd);
|
||||
|
|
@ -1263,28 +1269,22 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
|
|||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
// required, optional and invalid input
|
||||
status = RunOptionalInputTest(true, true, true, version, sess_env);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(true, true, true, version, sess_env),
|
||||
"Invalid input name");
|
||||
|
||||
// missing required
|
||||
status = RunOptionalInputTest(false, true, false, version, sess_env);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
if (version == 3) {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
|
||||
} else {
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
|
||||
}
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(false, true, false, version, sess_env),
|
||||
(version == 3 ? "Invalid input name" : "Missing Input:"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionTest) {
|
||||
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
static void CreateFuseOpModel(const std::string& model_file_name) {
|
||||
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor.
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
|
|
@ -1307,18 +1307,19 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
outputs.push_back(&output_arg_2);
|
||||
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
ASSERT_STATUS_OK(graph.Resolve());
|
||||
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, FunctionTest) {
|
||||
std::string model_file_name = "execution_provider_test_graph.onnx";
|
||||
status = onnxruntime::Model::Save(model, model_file_name);
|
||||
CreateFuseOpModel(model_file_name);
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "ExecutionProviderTest.FunctionTest";
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
status = session_object.Load(model_file_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
InferenceSession session{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session.Initialize());
|
||||
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
|
@ -1329,11 +1330,14 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
std::vector<int64_t> dims_mul_x = {3, 2};
|
||||
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_x);
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
|
||||
&ml_value_x);
|
||||
OrtValue ml_value_y;
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_y);
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
|
||||
&ml_value_y);
|
||||
OrtValue ml_value_z;
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_z);
|
||||
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
|
||||
&ml_value_z);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
feeds.insert(std::make_pair("Y", ml_value_y));
|
||||
|
|
@ -1349,67 +1353,33 @@ TEST(ExecutionProviderTest, FunctionTest) {
|
|||
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
|
||||
|
||||
// Now run
|
||||
status = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
|
||||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
|
||||
InferenceSession session_object_2{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(
|
||||
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
ASSERT_STATUS_OK(session_object_2.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object_2.Initialize());
|
||||
ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches));
|
||||
InferenceSession session2{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
ASSERT_STATUS_OK(session2.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session2.Initialize());
|
||||
ASSERT_STATUS_OK(session2.Run(run_options, feeds, output_names, &fetches));
|
||||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
|
||||
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor.
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
|
||||
|
||||
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
|
||||
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
|
||||
inputs.push_back(&input_arg_1);
|
||||
inputs.push_back(&input_arg_2);
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
|
||||
outputs.push_back(&output_arg);
|
||||
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
|
||||
|
||||
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
|
||||
inputs.clear();
|
||||
inputs.push_back(&output_arg);
|
||||
inputs.push_back(&input_arg_3);
|
||||
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_2);
|
||||
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
std::string model_file_name = "fused_node_shape_inference_test_graph.onnx";
|
||||
status = onnxruntime::Model::Save(model, model_file_name);
|
||||
|
||||
CreateFuseOpModel(model_file_name);
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest";
|
||||
InferenceSessionWrapper session{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(
|
||||
session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
status = session.Load(model_file_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
ASSERT_STATUS_OK(session.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session.Initialize());
|
||||
|
||||
Graph& fused_graph = session.GetMutableGraph();
|
||||
ASSERT_TRUE(fused_graph.NumberOfNodes() == 1);
|
||||
ASSERT_EQ(fused_graph.NumberOfNodes(), 1);
|
||||
auto& fused_node = *fused_graph.Nodes().begin();
|
||||
ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused);
|
||||
ASSERT_EQ(fused_node.NodeType(), Node::Type::Fused);
|
||||
ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function());
|
||||
|
||||
// Clear shape inference data from output node to verify that assigned inference function is called
|
||||
|
|
@ -1419,7 +1389,25 @@ TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
|
|||
ASSERT_STATUS_OK(fused_graph.Resolve());
|
||||
|
||||
ASSERT_TRUE(fused_node_output.Shape() != nullptr);
|
||||
ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()) == utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape()));
|
||||
ASSERT_EQ(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()), TensorShape({3, 2}));
|
||||
}
|
||||
|
||||
TEST(ExecutionProviderTest, OpKernelInfoCanReadConfigOptions) {
|
||||
std::string model_file_name = "OpKernelInfoCanReadConfigOptions.onnx";
|
||||
CreateFuseOpModel(model_file_name);
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "ExecutionProviderTest.OpKernelInfoCanReadConfigOptions";
|
||||
|
||||
// add a config key that if read causes the Fuse op kernel to throw in the ctor. this is just to test the value is passed
|
||||
// through in the simplest way, as the kernel is constructed in InferenceSession::Intialize so we don't need to
|
||||
// actually run the model.
|
||||
ASSERT_STATUS_OK(so.config_options.AddConfigEntry("ThrowInKernelCtor", "1"));
|
||||
|
||||
InferenceSession session{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
|
||||
ASSERT_STATUS_OK(session.Load(model_file_name));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Initialize(), "Test exception in ctor");
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
|
||||
|
|
|
|||
|
|
@ -84,9 +84,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
|
|||
auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
|
||||
OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(),
|
||||
s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr());
|
||||
unique_ptr<TestOpKernel> p_kernel;
|
||||
p_kernel.reset(new TestOpKernel(p_info));
|
||||
s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr(), s.GetAllocators(),
|
||||
s.GetSessionOptions().config_options);
|
||||
|
||||
std::unique_ptr<TestOpKernel> p_kernel = std::make_unique<TestOpKernel>(p_info);
|
||||
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
|
||||
std::cout << "node_idx: " << node.Index() << std::endl;
|
||||
|
||||
|
|
|
|||
|
|
@ -1503,10 +1503,8 @@ TEST_F(GraphTest, ShapeInferenceErrorHandling) {
|
|||
|
||||
graph.AddNode("node_1", "ShapeInferenceThrowsOp", "node 1", {&input_arg1}, {&output_arg1});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_FALSE(status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Node (node_1) Op (ShapeInferenceThrowsOp) "
|
||||
"[ShapeInferenceError] try harder"));
|
||||
EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(graph.Resolve(),
|
||||
"Node (node_1) Op (ShapeInferenceThrowsOp) [ShapeInferenceError] try harder");
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, AddTensorAttribute) {
|
||||
|
|
@ -2024,10 +2022,9 @@ TEST_F(GraphTest, LoadModelMissingInput) {
|
|||
SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2});
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
Status st = Model::Load(std::move(m), model, nullptr, *logger_);
|
||||
ASSERT_FALSE(st.IsOK());
|
||||
ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, "
|
||||
"initializer, or output of a previous node."));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(Model::Load(std::move(m), model, nullptr, *logger_),
|
||||
"Invalid model. Node input 'y' is not a graph input, "
|
||||
"initializer, or output of a previous node.");
|
||||
}
|
||||
|
||||
// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph.
|
||||
|
|
|
|||
|
|
@ -69,7 +69,18 @@ struct KernelAndDef {
|
|||
.SetDomain(domain)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.Build();
|
||||
OpKernelInfo info(main_node, *out.def, *out.a, {}, {}, {});
|
||||
|
||||
// these usually come from the session state. OpKernelInfo stores references to them so we need a valid backing
|
||||
// instance even though we don't use them in this test.
|
||||
static const std::unordered_map<int, OrtValue> constant_initialized_tensors;
|
||||
static const OrtValueNameIdxMap mlvalue_name_idx_map;
|
||||
static const DataTransferManager data_transfer_mgr;
|
||||
static const AllocatorMap allocators;
|
||||
static const ConfigOptions config_options;
|
||||
OpKernelInfo info(main_node, *out.def, *out.a,
|
||||
constant_initialized_tensors, mlvalue_name_idx_map, data_transfer_mgr, allocators,
|
||||
config_options);
|
||||
|
||||
out.kernel = std::make_unique<KernelType>(info);
|
||||
return out;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/optimizer/common_subexpression_elimination.h"
|
||||
#include "core/optimizer/graph_transformer_mgr.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "orttraining/core/optimizer/graph_transformer_utils.h"
|
||||
|
|
@ -272,20 +273,21 @@ TEST(CseTests, MergedValueAndGraphOutputAreOutputsOfSameNode) {
|
|||
TEST(CseTests, MergeConstants) {
|
||||
auto model_uri = ORT_TSTR("testdata/transform/cse/cse_merge_constants.onnx");
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger())
|
||||
.IsOK());
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
GraphTransformerManager graph_transformation_mgr(1);
|
||||
// In current implementation, equal constants are not merged. So CSE must precede constant folding, otherwise we end up
|
||||
// with multiple copies of the same constant.
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
ASSERT_TRUE(
|
||||
graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1).IsOK());
|
||||
ASSERT_TRUE(
|
||||
graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1).IsOK());
|
||||
ASSERT_TRUE(
|
||||
graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(),
|
||||
TransformerLevel::Level1));
|
||||
const ConfigOptions empty_config_options;
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1,
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
ASSERT_EQ(graph.GetAllInitializedTensors().size(), 1U);
|
||||
auto op_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
|
|
@ -575,12 +575,14 @@ TEST_F(GraphTransformationTests, ConstantFolding) {
|
|||
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
ASSERT_EQ(op_to_count["Unsqueeze"], 2);
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
|
|
@ -595,11 +597,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
|
|||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
// assign all nodes to CUDA. the constant folding should override this to perform the constant folding on cpu
|
||||
for (auto& node : graph.Nodes()) {
|
||||
|
|
@ -624,11 +628,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
|
|||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 1);
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
// assign all nodes to CUDA. the constant folding should try folding the node on the CPU and fail, thus leaving the
|
||||
// EP as CUDA and not constant folding the node.
|
||||
|
|
@ -707,11 +712,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingSubgraph) {
|
|||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 2); // one in each subgraph
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
|
|
@ -731,14 +737,15 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithShapeToInitializer) {
|
|||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 3);
|
||||
|
||||
InlinedHashSet<std::string_view> compatible_eps;
|
||||
InlinedHashSet<std::string> excluded_initializers;
|
||||
excluded_initializers.insert("matmul_weight");
|
||||
InlinedHashSet<std::string> excluded_initializers = {"matmul_weight"};
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(),
|
||||
false /*skip_dequantize_linear*/,
|
||||
empty_config_options,
|
||||
compatible_eps,
|
||||
excluded_initializers),
|
||||
TransformerLevel::Level1));
|
||||
|
|
@ -763,11 +770,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithScalarShapeToInitializer) {
|
|||
|
||||
InlinedHashSet<std::string_view> compatible_eps;
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(),
|
||||
false /*skip_dequantize_linear*/,
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
|
||||
compatible_eps),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
|
|
@ -792,11 +799,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingForOpsWithMissingOptionalInputs)
|
|||
|
||||
InlinedHashSet<std::string_view> compatible_eps;
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(),
|
||||
false /*skip_dequantize_linear*/,
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
|
||||
compatible_eps),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
|
|
@ -965,11 +972,12 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
|
|||
ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape
|
||||
ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
|
|
@ -988,10 +996,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
|
|||
ASSERT_TRUE(op_to_count["Shape"] == 4);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
const ConfigOptions empty_config_options;
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -1014,9 +1025,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) {
|
|||
ASSERT_EQ(op_to_count["Identity"], 1);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const ConfigOptions empty_config_options;
|
||||
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
|
||||
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
|
||||
TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ namespace test {
|
|||
static const std::string MODEL_FOLDER = "testdata/transform/";
|
||||
|
||||
TEST(OptimizerTest, Basic) {
|
||||
Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
constexpr int tensor_dim = 10;
|
||||
|
|
@ -65,8 +66,7 @@ TEST(OptimizerTest, Basic) {
|
|||
nodes.push_back(&node);
|
||||
}
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
|
||||
graph.ModelPath(),
|
||||
|
|
@ -85,8 +85,10 @@ TEST(OptimizerTest, Basic) {
|
|||
OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
|
||||
const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
|
||||
|
||||
const ConfigOptions empty_config_options;
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
auto kernel = info.CreateKernel(&node);
|
||||
auto kernel = info.CreateKernel(&node, empty_config_options);
|
||||
|
||||
// kernel can only be a nullptr if a CPU kernel implementation has been removed,
|
||||
// if that is the case, OpKernelContext instance construction will throw in the next step
|
||||
|
|
|
|||
|
|
@ -248,10 +248,9 @@ static common::Status CreateSubgraph(Graph& graph, RunOptions& options, const st
|
|||
auto status = graph.Resolve();
|
||||
|
||||
if (failure_message.empty()) {
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
EXPECT_STATUS_OK(status);
|
||||
} else {
|
||||
EXPECT_TRUE(!status.IsOK());
|
||||
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr(failure_message));
|
||||
EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(status, failure_message);
|
||||
}
|
||||
|
||||
return status;
|
||||
|
|
|
|||
|
|
@ -153,9 +153,8 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
|
|||
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
|
||||
|
||||
ASSERT_STATUS_OK(session->Load(ort_model_path));
|
||||
auto status = session->Initialize();
|
||||
ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels";
|
||||
ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes"));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session->Initialize(),
|
||||
"Unable to serialize model as it contains compiled nodes");
|
||||
}
|
||||
|
||||
// the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC
|
||||
|
|
@ -195,11 +194,10 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
|
|||
output_names.push_back("Z");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
auto status = session.Run(feeds, output_names, &fetches);
|
||||
// Error message should come from the Conv implementation with the statically registered kernel
|
||||
ASSERT_THAT(status.ErrorMessage(),
|
||||
::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
|
||||
"Status Message: TODO: add NHWC implementation here."));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
|
||||
"Non-zero status code returned while running Conv node. Name:'Conv' "
|
||||
"Status Message: TODO: add NHWC implementation here.");
|
||||
}
|
||||
|
||||
TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
|
||||
|
|
@ -243,10 +241,9 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
|
|||
output_names.push_back("softmaxout_1");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
auto status = session.Run(feeds, output_names, &fetches);
|
||||
ASSERT_THAT(status.ErrorMessage(),
|
||||
::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
|
||||
"Status Message: TODO: add NHWC implementation here."));
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
|
||||
"Non-zero status code returned while running Conv node. Name:'Conv' "
|
||||
"Status Message: TODO: add NHWC implementation here.");
|
||||
}
|
||||
|
||||
// This test can be deprecated now as the code logic has been changed so the model is not applicable
|
||||
|
|
|
|||
|
|
@ -124,7 +124,8 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
|
|||
outputs.emplace_back(output);
|
||||
}
|
||||
|
||||
auto kernel = info.CreateKernel(&node);
|
||||
static const ConfigOptions empty_config_options;
|
||||
auto kernel = info.CreateKernel(&node, empty_config_options);
|
||||
ASSERT_TRUE(kernel);
|
||||
|
||||
std::vector<int> fetch_mlvalue_idxs;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/common/status.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
// helpers to run a function and check the status, outputting any error if it fails.
|
||||
// note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
|
||||
|
|
@ -33,6 +34,20 @@
|
|||
EXPECT_FALSE(_tmp_status.IsOK()); \
|
||||
} while (false)
|
||||
|
||||
#define ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg) \
|
||||
do { \
|
||||
Status _tmp_status = (function); \
|
||||
ASSERT_FALSE(_tmp_status.IsOK()); \
|
||||
ASSERT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
|
||||
} while (false)
|
||||
|
||||
#define EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg) \
|
||||
do { \
|
||||
Status _tmp_status = (function); \
|
||||
EXPECT_FALSE(_tmp_status.IsOK()); \
|
||||
EXPECT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
|
||||
} while (false)
|
||||
|
||||
// Same helpers for public API OrtStatus. Get the 'api' instance using:
|
||||
// const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
#define ASSERT_ORTSTATUS_OK(api, function) \
|
||||
|
|
|
|||
|
|
@ -157,8 +157,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
transformers.emplace_back(std::make_unique<GeluApproximation>(compatible_eps));
|
||||
}
|
||||
InlinedHashSet<std::string> excluded_initializers(weights_to_train.begin(), weights_to_train.end());
|
||||
static const ConfigOptions empty_config_options;
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(
|
||||
execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers));
|
||||
execution_provider, false /*skip_dequantize_linear*/, empty_config_options, compatible_eps,
|
||||
excluded_initializers));
|
||||
transformers.emplace_back(std::make_unique<ReshapeFusion>(compatible_eps));
|
||||
// Put fine-grained optimizer (e.g. ShapeOptimizer) after ReshapeFusion to avoid it breaks the strong patterns
|
||||
// it defines. ReshapeFusion depends on subgraph pattern matching and do replacement accordingly, ShapeOptimizer
|
||||
|
|
|
|||
Loading…
Reference in a new issue