Make session configuration options available to kernels via OpKernelInfo (#18897)

### Description
<!-- Describe your changes. -->
Pass through the ConfigOptions from the session via OpKernelInfo so that
kernel behavior can be configured.

Initial usage would be to optionally enable a fast path for ARM64 bfloat16 GEMM - see #17031
Other usages could be things like selecting the exact implementations of the activation functions for RNN operators instead of the default approximations (e.g. use [sigmoid_exact instead of sigmoid](2d6e2e243d/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h (L379-L382)))

OpKernelInfo is already passing through things from the session state, and adding a new member of ConfigOptions
is the simpler update. It's also a more natural fit given it's providing state/info to the kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Scott McKay 2024-01-13 10:02:43 +10:00 committed by GitHub
parent a503561d0c
commit 8f2e57f5d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 246 additions and 162 deletions

View file

@ -28,7 +28,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& mlvalue_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const AllocatorMap& allocators = {});
const AllocatorMap& allocators,
const ConfigOptions& config_options);
OpKernelInfo(const OpKernelInfo& other);
@ -50,6 +51,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const AllocatorMap& GetAllocators() const { return allocators_; }
// Returns the ConfigOptions this OpKernelInfo was constructed with. Only a reference is held
// (config_options_ is a `const ConfigOptions&` member), so the referenced instance must outlive this object.
const ConfigOptions& GetConfigOptions() const { return config_options_; }
private:
ORT_DISALLOW_MOVE(OpKernelInfo);
ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
@ -64,6 +67,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const DataTransferManager& data_transfer_mgr_;
ProtoHelperNodeContext proto_helper_context_;
const AllocatorMap& allocators_;
const ConfigOptions& config_options_;
};
} // namespace onnxruntime

View file

@ -24,7 +24,8 @@ Status KernelRegistryManager::CreateKernel(const Node& node,
session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(),
session_state.GetDataTransferMgr(),
session_state.GetAllocators());
session_state.GetAllocators(),
session_state.GetSessionOptions().config_options);
return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
}

View file

@ -15,7 +15,8 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const AllocatorMap& allocators)
const AllocatorMap& allocators,
const ConfigOptions& config_options)
: OpNodeProtoHelper(&proto_helper_context_),
node_(node),
kernel_def_(kernel_def),
@ -24,15 +25,22 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
ort_value_name_idx_map_(ort_value_name_idx_map),
data_transfer_mgr_(data_transfer_mgr),
proto_helper_context_(node),
allocators_(allocators) {}
allocators_(allocators),
config_options_(config_options) {
}
OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
other.ort_value_name_idx_map_, other.data_transfer_mgr_, other.allocators_) {}
other.ort_value_name_idx_map_, other.data_transfer_mgr_,
other.allocators_, other.config_options_) {
}
AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const {
auto it = allocators_.find(execution_provider_->GetOrtDeviceByMemType(mem_type));
if (it != allocators_.end()) return it->second;
if (it != allocators_.end()) {
return it->second;
}
return nullptr;
}

View file

@ -18,10 +18,12 @@ namespace onnxruntime {
// Constructs the "ConstantFolding" graph transformer.
// execution_provider and config_options are stored as references (execution_provider_, config_options_), so both
// must outlive this transformer. excluded_initializers is copied into excluded_initializers_.
// config_options_ is later handed to OptimizerExecutionFrame::Info::CreateKernel when kernels are created.
ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: GraphTransformer("ConstantFolding", compatible_execution_providers),
skip_dequantize_linear_(skip_dequantize_linear),
config_options_(config_options),
excluded_initializers_(excluded_initializers),
execution_provider_(execution_provider) {
}
@ -250,12 +252,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// override the EP assigned to the node so that it will use the CPU kernel for Compute.
node->SetExecutionProviderType(kCpuExecutionProvider);
kernel = info.CreateKernel(node);
kernel = info.CreateKernel(node, config_options_);
// undo the EP change to the value that was assigned at graph partitioning time
node->SetExecutionProviderType(ep_type);
} else {
kernel = info.CreateKernel(node);
kernel = info.CreateKernel(node, config_options_);
}
// We currently constant fold using the CPU EP only.

View file

@ -24,6 +24,7 @@ class ConstantFolding : public GraphTransformer {
*/
ConstantFolding(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
@ -31,6 +32,7 @@ class ConstantFolding : public GraphTransformer {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool skip_dequantize_linear_;
const ConfigOptions& config_options_;
const InlinedHashSet<std::string> excluded_initializers_;
const IExecutionProvider& execution_provider_;
};

View file

@ -223,7 +223,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
transformers.emplace_back(std::make_unique<MatMulAddFusion>());
transformers.emplace_back(std::make_unique<ReshapeFusion>());
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(

View file

@ -128,26 +128,34 @@ static Status TryCreateKernel(const Node& node,
const OrtValueNameIdxMap& ort_value_name_idx_map,
FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
const ConfigOptions& config_options,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelCreateInfo* kernel_create_info = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
&kernel_create_info));
static const AllocatorMap dummy_allocators;
OpKernelInfo kernel_info(node,
*kernel_create_info->kernel_def,
execution_provider,
constant_initialized_tensors,
ort_value_name_idx_map,
data_transfer_mgr);
data_transfer_mgr,
dummy_allocators,
config_options);
return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel);
}
std::unique_ptr<const OpKernel> OptimizerExecutionFrame::Info::CreateKernel(const Node* node) const {
std::unique_ptr<const OpKernel>
OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOptions& config_options) const {
std::unique_ptr<OpKernel> op_kernel;
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
FuncManager func;
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
ort_value_name_idx_map_, func, data_transfer_mgr_,
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
op_kernel);
// Kernel found in the CPU kernel registry

View file

@ -27,11 +27,13 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
const Path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const Path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);
~Info() = default;
const AllocatorPtr& GetAllocator() const {
@ -52,7 +54,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
return -1;
}
std::unique_ptr<const OpKernel> CreateKernel(const Node* node) const;
std::unique_ptr<const OpKernel> CreateKernel(const Node* node, const ConfigOptions& config_options) const;
// Check if a kernel create info can be found in the registry.
Status TryFindKernel(const Node* node, const KernelCreateInfo** out) const;

View file

@ -132,6 +132,7 @@ struct Logger;
struct Capture;
} // namespace logging
struct ComputeCapability;
struct ConfigOptions;
struct DataTransferManager;
struct IndexedSubGraph;
struct IndexedSubGraph_MetaDef;

View file

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <optional>
// Public wrappers around internal ort interfaces (currently)
#include "core/providers/shared_library/provider_host_api.h"
@ -426,6 +428,9 @@ struct ProviderHost {
virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;
// ConfigOptions
virtual std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0;
// ComputeCapability
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
@ -808,6 +813,7 @@ struct ProviderHost {
virtual uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) = 0;
virtual uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) = 0;
virtual const Node& OpKernelInfo__node(const OpKernelInfo* p) = 0;
virtual const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) = 0;
// SessionState
virtual const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) = 0;

View file

@ -335,6 +335,14 @@ struct DataTypeUtils final {
} // namespace Utils
// Provider-side view of ConfigOptions. It carries no state of its own here; lookups are forwarded to the host
// through g_host, passing `this` as the opaque handle.
struct ConfigOptions final {
// Looks up config_key via the host. Returns whatever ProviderHost::ConfigOptions__GetConfigEntry returns
// (presumably std::nullopt when the key is not set - confirm against the host-side ConfigOptions).
std::optional<std::string> GetConfigEntry(const std::string& config_key) const {
return g_host->ConfigOptions__GetConfigEntry(this, config_key);
}
// NOTE(review): presumably disables construction/copy/assignment like the other wrapper types in this header -
// confirm against the macro definition.
PROVIDER_DISALLOW_ALL(ConfigOptions)
};
struct ComputeCapability final {
static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
@ -901,6 +909,8 @@ struct OpKernelInfo final {
const Node& node() const noexcept { return g_host->OpKernelInfo__node(this); }
// Returns the ConfigOptions associated with this kernel info, obtained from the host via g_host.
const ConfigOptions& GetConfigOptions() const { return g_host->OpKernelInfo__GetConfigOptions(this); }
OpKernelInfo() = delete;
OpKernelInfo(const OpKernelInfo&) = delete;
void operator=(const OpKernelInfo&) = delete;

View file

@ -6,6 +6,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/allocator_utils.h"
#include "core/framework/config_options.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_types.h"
#include "core/framework/data_transfer_manager.h"
@ -529,6 +530,11 @@ struct ProviderHostImpl : ProviderHost {
const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }
// ConfigOptions (wrapped)
// Host-side implementation of the provider bridge call: forwards to ConfigOptions::GetConfigEntry on p.
// p is dereferenced without a null check, so callers must pass a valid instance.
std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override {
return p->GetConfigEntry(config_key);
}
// ComputeCapability (wrapped)
std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
@ -934,6 +940,7 @@ struct ProviderHostImpl : ProviderHost {
uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) override { return p->GetInputCount(); }
uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) override { return p->GetOutputCount(); }
const Node& OpKernelInfo__node(const OpKernelInfo* p) override { return p->node(); }
// Host-side implementation of the provider bridge call: returns p->GetConfigOptions(). p must be non-null.
const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) override { return p->GetConfigOptions(); }
// SessionState (wrapped)
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }

View file

@ -421,7 +421,10 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info,
static const OrtValueNameIdxMap kEmptyNameMap;
OpKernelInfo tmp_kernel_info(*node_ptr.get(), *kernel_def, *ep, kEmptyValueMap, kEmptyNameMap,
kernel_info->GetDataTransferManager(), kernel_info->GetAllocators());
kernel_info->GetDataTransferManager(),
kernel_info->GetAllocators(),
kernel_info->GetConfigOptions());
std::unique_ptr<onnxruntime::OpKernel> op_kernel;
auto& node_repo = NodeRepo::GetInstance();

View file

@ -254,7 +254,7 @@ class PlannerTest : public ::testing::Test {
ASSERT_NE(ep, nullptr);
auto info = std::make_unique<OpKernelInfo>(
*p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(),
state_->GetDataTransferMgr());
state_->GetDataTransferMgr(), state_->GetAllocators(), state_->GetSessionOptions().config_options);
op_kernel_infos_.push_back(std::move(info));
const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{};

View file

@ -82,6 +82,11 @@ ProviderInfo_ROCM& GetProviderInfo_ROCM();
class FuseAdd : public OpKernel {
public:
explicit FuseAdd(const OpKernelInfo& info) : OpKernel(info) {
  // Test hook: lets a test verify that session config options reach the kernel through OpKernelInfo.
  // When the "ThrowInKernelCtor" entry is set to "1", construction fails with a recognizable message.
  // If the entry is absent the optional is empty, the comparison is false, and construction proceeds.
  const auto test_throw_in_ctor = info.GetConfigOptions().GetConfigEntry("ThrowInKernelCtor");
  if (test_throw_in_ctor == "1") {
    ORT_THROW("Test exception in ctor");
  }
}
Status Compute(OpKernelContext* context) const override {
@ -96,6 +101,7 @@ class FuseAdd : public OpKernel {
return Status::OK();
}
};
constexpr const char* kFuseTest = "FuseTest";
constexpr const char* kFuseExecutionProvider = "FuseExecutionProvider";
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kFuseExecutionProvider, kFuseTest, 1, FuseAdd);
@ -1263,28 +1269,22 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
}
// required, optional and invalid input
status = RunOptionalInputTest(true, true, true, version, sess_env);
ASSERT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(true, true, true, version, sess_env),
"Invalid input name");
// missing required
status = RunOptionalInputTest(false, true, false, version, sess_env);
ASSERT_FALSE(status.IsOK());
if (version == 3) {
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
} else {
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
}
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(false, true, false, version, sess_env),
(version == 3 ? "Invalid input name" : "Missing Input:"));
}
}
TEST(ExecutionProviderTest, FunctionTest) {
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
static void CreateFuseOpModel(const std::string& model_file_name) {
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
// FLOAT tensor.
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
@ -1307,18 +1307,19 @@ TEST(ExecutionProviderTest, FunctionTest) {
outputs.push_back(&output_arg_2);
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(graph.Resolve());
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
}
TEST(ExecutionProviderTest, FunctionTest) {
std::string model_file_name = "execution_provider_test_graph.onnx";
status = onnxruntime::Model::Save(model, model_file_name);
CreateFuseOpModel(model_file_name);
SessionOptions so;
so.session_logid = "ExecutionProviderTest.FunctionTest";
InferenceSession session_object{so, GetEnvironment()};
status = session_object.Load(model_file_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_file_name));
ASSERT_STATUS_OK(session.Initialize());
RunOptions run_options;
run_options.run_tag = so.session_logid;
@ -1329,11 +1330,14 @@ TEST(ExecutionProviderTest, FunctionTest) {
std::vector<int64_t> dims_mul_x = {3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
OrtValue ml_value_x;
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_x);
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
&ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_y);
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
&ml_value_y);
OrtValue ml_value_z;
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_z);
CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
&ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
@ -1349,67 +1353,33 @@ TEST(ExecutionProviderTest, FunctionTest) {
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
// Now run
status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
InferenceSession session_object_2{so, GetEnvironment()};
ASSERT_STATUS_OK(
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
ASSERT_STATUS_OK(session_object_2.Load(model_file_name));
ASSERT_STATUS_OK(session_object_2.Initialize());
ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches));
InferenceSession session2{so, GetEnvironment()};
ASSERT_STATUS_OK(session2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
ASSERT_STATUS_OK(session2.Load(model_file_name));
ASSERT_STATUS_OK(session2.Initialize());
ASSERT_STATUS_OK(session2.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}
TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
// FLOAT tensor.
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
inputs.push_back(&input_arg_1);
inputs.push_back(&input_arg_2);
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
outputs.push_back(&output_arg);
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
inputs.clear();
inputs.push_back(&output_arg);
inputs.push_back(&input_arg_3);
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
outputs.clear();
outputs.push_back(&output_arg_2);
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
std::string model_file_name = "fused_node_shape_inference_test_graph.onnx";
status = onnxruntime::Model::Save(model, model_file_name);
CreateFuseOpModel(model_file_name);
SessionOptions so;
so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest";
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(
session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
status = session.Load(model_file_name);
ASSERT_TRUE(status.IsOK());
status = session.Initialize();
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
ASSERT_STATUS_OK(session.Load(model_file_name));
ASSERT_STATUS_OK(session.Initialize());
Graph& fused_graph = session.GetMutableGraph();
ASSERT_TRUE(fused_graph.NumberOfNodes() == 1);
ASSERT_EQ(fused_graph.NumberOfNodes(), 1);
auto& fused_node = *fused_graph.Nodes().begin();
ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused);
ASSERT_EQ(fused_node.NodeType(), Node::Type::Fused);
ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function());
// Clear shape inference data from output node to verify that assigned inference function is called
@ -1419,7 +1389,25 @@ TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
ASSERT_STATUS_OK(fused_graph.Resolve());
ASSERT_TRUE(fused_node_output.Shape() != nullptr);
ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()) == utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape()));
ASSERT_EQ(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()), TensorShape({3, 2}));
}
// Verifies that a config entry set on SessionOptions is visible to a kernel via OpKernelInfo::GetConfigOptions().
TEST(ExecutionProviderTest, OpKernelInfoCanReadConfigOptions) {
std::string model_file_name = "OpKernelInfoCanReadConfigOptions.onnx";
CreateFuseOpModel(model_file_name);
SessionOptions so;
so.session_logid = "ExecutionProviderTest.OpKernelInfoCanReadConfigOptions";
// Add a config key that, if read, causes the Fuse op kernel to throw in its ctor. This is the simplest way to
// check that the value is passed through: the kernel is constructed in InferenceSession::Initialize, so we
// don't need to actually run the model.
ASSERT_STATUS_OK(so.config_options.AddConfigEntry("ThrowInKernelCtor", "1"));
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
ASSERT_STATUS_OK(session.Load(model_file_name));
// Initialize is expected to fail, surfacing the message thrown from the FuseAdd constructor.
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Initialize(), "Test exception in ctor");
}
TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {

View file

@ -84,9 +84,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(),
s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr());
unique_ptr<TestOpKernel> p_kernel;
p_kernel.reset(new TestOpKernel(p_info));
s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr(), s.GetAllocators(),
s.GetSessionOptions().config_options);
std::unique_ptr<TestOpKernel> p_kernel = std::make_unique<TestOpKernel>(p_info);
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
std::cout << "node_idx: " << node.Index() << std::endl;

View file

@ -1503,10 +1503,8 @@ TEST_F(GraphTest, ShapeInferenceErrorHandling) {
graph.AddNode("node_1", "ShapeInferenceThrowsOp", "node 1", {&input_arg1}, {&output_arg1});
auto status = graph.Resolve();
EXPECT_FALSE(status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Node (node_1) Op (ShapeInferenceThrowsOp) "
"[ShapeInferenceError] try harder"));
EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(graph.Resolve(),
"Node (node_1) Op (ShapeInferenceThrowsOp) [ShapeInferenceError] try harder");
}
TEST_F(GraphTest, AddTensorAttribute) {
@ -2024,10 +2022,9 @@ TEST_F(GraphTest, LoadModelMissingInput) {
SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2});
std::shared_ptr<Model> model;
Status st = Model::Load(std::move(m), model, nullptr, *logger_);
ASSERT_FALSE(st.IsOK());
ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, "
"initializer, or output of a previous node."));
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(Model::Load(std::move(m), model, nullptr, *logger_),
"Invalid model. Node input 'y' is not a graph input, "
"initializer, or output of a previous node.");
}
// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph.

View file

@ -69,7 +69,18 @@ struct KernelAndDef {
.SetDomain(domain)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.Build();
OpKernelInfo info(main_node, *out.def, *out.a, {}, {}, {});
// these usually come from the session state. OpKernelInfo stores references to them so we need a valid backing
// instance even though we don't use them in this test.
static const std::unordered_map<int, OrtValue> constant_initialized_tensors;
static const OrtValueNameIdxMap mlvalue_name_idx_map;
static const DataTransferManager data_transfer_mgr;
static const AllocatorMap allocators;
static const ConfigOptions config_options;
OpKernelInfo info(main_node, *out.def, *out.a,
constant_initialized_tensors, mlvalue_name_idx_map, data_transfer_mgr, allocators,
config_options);
out.kernel = std::make_unique<KernelType>(info);
return out;
}

View file

@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/framework/test_utils.h"
#include "test/test_environment.h"
#include "core/graph/model.h"
#include "core/optimizer/common_subexpression_elimination.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "test/framework/test_utils.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#ifdef ENABLE_TRAINING
#include "orttraining/core/optimizer/graph_transformer_utils.h"
@ -272,20 +273,21 @@ TEST(CseTests, MergedValueAndGraphOutputAreOutputsOfSameNode) {
TEST(CseTests, MergeConstants) {
auto model_uri = ORT_TSTR("testdata/transform/cse/cse_merge_constants.onnx");
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model, nullptr,
DefaultLoggingManager().DefaultLogger())
.IsOK());
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()));
Graph& graph = model->MainGraph();
GraphTransformerManager graph_transformation_mgr(1);
// In current implementation, equal constants are not merged. So CSE must precede constant folding, otherwise we end up
// with multiple copies of the same constant.
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_TRUE(
graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1).IsOK());
ASSERT_TRUE(
graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1).IsOK());
ASSERT_TRUE(
graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(),
TransformerLevel::Level1));
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1,
DefaultLoggingManager().DefaultLogger()));
ASSERT_EQ(graph.GetAllInitializedTensors().size(), 1U);
auto op_count = CountOpsInGraph(graph);

View file

@ -575,12 +575,14 @@ TEST_F(GraphTransformationTests, ConstantFolding) {
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_EQ(op_to_count["Unsqueeze"], 2);
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
@ -595,11 +597,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
// assign all nodes to CUDA. the constant folding should override this to perform the constant folding on cpu
for (auto& node : graph.Nodes()) {
@ -624,11 +628,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Mul"] == 1);
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
// assign all nodes to CUDA. the constant folding should try folding the node on the CPU and fail, thus leaving the
// EP as CUDA and not constant folding the node.
@ -707,11 +712,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingSubgraph) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 2); // one in each subgraph
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
@ -731,14 +737,15 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithShapeToInitializer) {
ASSERT_TRUE(op_to_count["Unsqueeze"] == 3);
InlinedHashSet<std::string_view> compatible_eps;
InlinedHashSet<std::string> excluded_initializers;
excluded_initializers.insert("matmul_weight");
InlinedHashSet<std::string> excluded_initializers = {"matmul_weight"};
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
const ConfigOptions empty_config_options;
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(),
false /*skip_dequantize_linear*/,
empty_config_options,
compatible_eps,
excluded_initializers),
TransformerLevel::Level1));
@ -763,11 +770,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithScalarShapeToInitializer) {
InlinedHashSet<std::string_view> compatible_eps;
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
const ConfigOptions empty_config_options;
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(),
false /*skip_dequantize_linear*/,
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
compatible_eps),
TransformerLevel::Level1));
@ -792,11 +799,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingForOpsWithMissingOptionalInputs)
InlinedHashSet<std::string_view> compatible_eps;
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
const ConfigOptions empty_config_options;
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(),
false /*skip_dequantize_linear*/,
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
compatible_eps),
TransformerLevel::Level1));
@ -965,11 +972,12 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape
ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
@ -988,10 +996,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
ASSERT_TRUE(op_to_count["Shape"] == 4);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
const ConfigOptions empty_config_options;
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
@ -1014,9 +1025,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) {
ASSERT_EQ(op_to_count["Identity"], 1);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const ConfigOptions empty_config_options;
std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);

View file

@ -27,7 +27,8 @@ namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
TEST(OptimizerTest, Basic) {
Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
constexpr int tensor_dim = 10;
@ -65,8 +66,7 @@ TEST(OptimizerTest, Basic) {
nodes.push_back(&node);
}
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
#if !defined(DISABLE_SPARSE_TENSORS)
OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set,
graph.ModelPath(),
@ -85,8 +85,10 @@ TEST(OptimizerTest, Basic) {
OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
const ConfigOptions empty_config_options;
for (auto& node : graph.Nodes()) {
auto kernel = info.CreateKernel(&node);
auto kernel = info.CreateKernel(&node, empty_config_options);
// kernel can only be a nullptr if a CPU kernel implementation has been removed,
// if that is the case, OpKernelContext instance construction will throw in the next step

View file

@ -248,10 +248,9 @@ static common::Status CreateSubgraph(Graph& graph, RunOptions& options, const st
auto status = graph.Resolve();
if (failure_message.empty()) {
EXPECT_EQ(status, Status::OK());
EXPECT_STATUS_OK(status);
} else {
EXPECT_TRUE(!status.IsOK());
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr(failure_message));
EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(status, failure_message);
}
return status;

View file

@ -153,9 +153,8 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
ASSERT_STATUS_OK(session->Load(ort_model_path));
auto status = session->Initialize();
ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels";
ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes"));
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session->Initialize(),
"Unable to serialize model as it contains compiled nodes");
}
// the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC
@ -195,11 +194,10 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
output_names.push_back("Z");
std::vector<OrtValue> fetches;
auto status = session.Run(feeds, output_names, &fetches);
// Error message should come from the Conv implementation with the statically registered kernel
ASSERT_THAT(status.ErrorMessage(),
::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
"Status Message: TODO: add NHWC implementation here."));
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
"Non-zero status code returned while running Conv node. Name:'Conv' "
"Status Message: TODO: add NHWC implementation here.");
}
TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
@ -243,10 +241,9 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
output_names.push_back("softmaxout_1");
std::vector<OrtValue> fetches;
auto status = session.Run(feeds, output_names, &fetches);
ASSERT_THAT(status.ErrorMessage(),
::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
"Status Message: TODO: add NHWC implementation here."));
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
"Non-zero status code returned while running Conv node. Name:'Conv' "
"Status Message: TODO: add NHWC implementation here.");
}
// This test can be deprecated now as the code logic has been changed so the model is not applicable

View file

@ -124,7 +124,8 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
outputs.emplace_back(output);
}
auto kernel = info.CreateKernel(&node);
static const ConfigOptions empty_config_options;
auto kernel = info.CreateKernel(&node, empty_config_options);
ASSERT_TRUE(kernel);
std::vector<int> fetch_mlvalue_idxs;

View file

@ -6,6 +6,7 @@
#include "core/common/status.h"
#include "core/session/onnxruntime_c_api.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
// helpers to run a function and check the status, outputting any error if it fails.
// note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
@ -33,6 +34,20 @@
EXPECT_FALSE(_tmp_status.IsOK()); \
} while (false)
// Fatal check that `function` fails with an expected error message.
//   function: expression convertible to Status; evaluated exactly once and stored in _tmp_status.
//   msg:      substring that must appear in _tmp_status.ErrorMessage() (matched with ::testing::HasSubstr).
// First asserts the status is not OK, then asserts the error message contains `msg`. Both checks use ASSERT_*,
// so the message check is only reached when the status is not OK.
// `Status` is unqualified here, so that name must be visible at the point where the macro is expanded.
// Wrapped in do{} while(false) so the _tmp_status variable has limited scope.
#define ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg) \
do { \
Status _tmp_status = (function); \
ASSERT_FALSE(_tmp_status.IsOK()); \
ASSERT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
} while (false)
// Non-fatal check that `function` fails with an expected error message.
//   function: expression convertible to Status; evaluated exactly once and stored in _tmp_status.
//   msg:      substring that must appear in _tmp_status.ErrorMessage() (matched with ::testing::HasSubstr).
// Uses EXPECT_* for both checks, so the error-message check is still evaluated when the status is
// unexpectedly OK, and the enclosing test continues after a failure.
// `Status` is unqualified here, so that name must be visible at the point where the macro is expanded.
// Wrapped in do{} while(false) so the _tmp_status variable has limited scope.
#define EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg) \
do { \
Status _tmp_status = (function); \
EXPECT_FALSE(_tmp_status.IsOK()); \
EXPECT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
} while (false)
// Same helpers for public API OrtStatus. Get the 'api' instance using:
// const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
#define ASSERT_ORTSTATUS_OK(api, function) \

View file

@ -157,8 +157,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<GeluApproximation>(compatible_eps));
}
InlinedHashSet<std::string> excluded_initializers(weights_to_train.begin(), weights_to_train.end());
static const ConfigOptions empty_config_options;
transformers.emplace_back(std::make_unique<ConstantFolding>(
execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers));
execution_provider, false /*skip_dequantize_linear*/, empty_config_options, compatible_eps,
excluded_initializers));
transformers.emplace_back(std::make_unique<ReshapeFusion>(compatible_eps));
// Put fine-grained optimizer (e.g. ShapeOptimizer) after ReshapeFusion to avoid it breaks the strong patterns
// it defines. ReshapeFusion depends on subgraph pattern matching and do replacement accordingly, ShapeOptimizer