mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
### Description
<!-- Describe your changes. -->
Pass through the ConfigOptions from the session via OpKernelInfo so that
kernel behavior can be configured.
The initial use is to optionally enable a fast path for ARM64 bfloat16 GEMM (see #17031).
Other uses could include selecting the exact implementations of the RNN operators' activation functions instead of the default approximations (e.g. use [sigmoid_exact instead of sigmoid](https://github.com/microsoft/onnxruntime/blob/2d6e2e243d/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h#L379-L382)).
OpKernelInfo already passes through information from the session state, so adding a ConfigOptions member to it
is the simpler update. It is also a more natural fit, given that OpKernelInfo exists to provide state/info to the kernel.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
79 lines
3.5 KiB
C
79 lines
3.5 KiB
C
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "core/common/status.h"
|
|
#include "core/session/onnxruntime_c_api.h"
|
|
#include "gtest/gtest.h"
|
|
#include "gmock/gmock.h"
|
|
|
|
// Helpers to run a function and check the returned onnxruntime Status, outputting any error if it fails.
// Note: each helper is wrapped in do {} while (false) so the _tmp_status variable has limited scope and
// the macro expands to a single statement (safe inside an unbraced if/else).

// ASSERT_STATUS_OK(function): evaluates `function` once, stores the result in a Status, and fatally fails
// the current test if it is not OK, streaming the Status into the failure message.
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define ASSERT_STATUS_OK(function)                              \
  do {                                                          \
    ::onnxruntime::common::Status _tmp_status = (function);     \
    ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status;             \
  } while (false)
|
|
|
|
// EXPECT_STATUS_OK(function): non-fatal counterpart of ASSERT_STATUS_OK. Evaluates `function` once and,
// if the resulting Status is not OK, records a failure (streaming the Status into the message) and lets
// the test continue.
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define EXPECT_STATUS_OK(function)                              \
  do {                                                          \
    ::onnxruntime::common::Status _tmp_status = (function);     \
    EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status;             \
  } while (false)
|
|
|
|
// ASSERT_STATUS_NOT_OK(function): evaluates `function` once and fatally fails the current test if the
// resulting Status is OK (i.e. the call was expected to fail but succeeded).
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define ASSERT_STATUS_NOT_OK(function)                          \
  do {                                                          \
    ::onnxruntime::common::Status _tmp_status = (function);     \
    ASSERT_FALSE(_tmp_status.IsOK());                           \
  } while (false)
|
|
|
|
// EXPECT_STATUS_NOT_OK(function): non-fatal counterpart of ASSERT_STATUS_NOT_OK. Evaluates `function`
// once and records a failure (without aborting the test) if the resulting Status is OK.
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define EXPECT_STATUS_NOT_OK(function)                          \
  do {                                                          \
    ::onnxruntime::common::Status _tmp_status = (function);     \
    EXPECT_FALSE(_tmp_status.IsOK());                           \
  } while (false)
|
|
|
|
// ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg): evaluates `function` once and fatally fails the
// current test unless the resulting Status is not OK AND its ErrorMessage() contains `msg` as a substring.
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)                    \
  do {                                                                        \
    ::onnxruntime::common::Status _tmp_status = (function);                   \
    ASSERT_FALSE(_tmp_status.IsOK());                                         \
    ASSERT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg));       \
  } while (false)
|
|
|
|
// EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg): non-fatal counterpart of
// ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR. Evaluates `function` once and records a failure for each check that
// does not hold: the Status must be not OK, and its ErrorMessage() must contain `msg` as a substring.
// Both checks always run, since EXPECT_* does not return early.
// Status is fully qualified so this macro does not depend on the caller's namespace or using-declarations.
#define EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)                    \
  do {                                                                        \
    ::onnxruntime::common::Status _tmp_status = (function);                   \
    EXPECT_FALSE(_tmp_status.IsOK());                                         \
    EXPECT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg));       \
  } while (false)
|
|
|
|
// Same helpers for the public API OrtStatus. Get the 'api' instance using:
//   const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
// `function` is an OrtApi member call written without the object, e.g. `CreateEnv(...)`; it is invoked
// as `api->function`. Note that `api` is evaluated more than once, so pass a simple pointer expression.

// ASSERT_ORTSTATUS_OK(api, function): fatally fails the current test if the call returns a non-null
// OrtStatus, including the status' error message in the failure output.
// The status is released BEFORE the assertion fires. A failing ASSERT_* returns from the enclosing
// function immediately, so a ReleaseStatus placed after it would be skipped and the status leaked.
#define ASSERT_ORTSTATUS_OK(api, function)                                                  \
  do {                                                                                      \
    OrtStatusPtr _tmp_status = (api->function);                                             \
    if (_tmp_status != nullptr) {                                                           \
      /* AssertionFailure() copies the message, so it remains valid after ReleaseStatus. */ \
      const ::testing::AssertionResult _tmp_result =                                        \
          ::testing::AssertionFailure() << api->GetErrorMessage(_tmp_status);               \
      api->ReleaseStatus(_tmp_status);                                                      \
      ASSERT_TRUE(_tmp_result);                                                             \
    }                                                                                       \
  } while (false)
|
|
|
|
// EXPECT_ORTSTATUS_OK(api, function): non-fatal check that `api->function` returns a null OrtStatus.
// On a non-null (error) status, it records a failure that includes the status' error message, releases
// the status, and lets the test continue. Because EXPECT_* never returns early, the release always runs.
#define EXPECT_ORTSTATUS_OK(api, function)                                       \
  do {                                                                           \
    OrtStatusPtr _tmp_status = (api->function);                                  \
    if (_tmp_status != nullptr) {                                                \
      EXPECT_EQ(_tmp_status, nullptr) << api->GetErrorMessage(_tmp_status);      \
      api->ReleaseStatus(_tmp_status);                                           \
    }                                                                            \
  } while (false)
|
|
|
|
// ASSERT_ORTSTATUS_NOT_OK(api, function): fatally fails the current test if `api->function` unexpectedly
// succeeds (returns a null OrtStatus). A non-null status is the expected outcome and is released here.
// Nothing can leak on the failing path, because it is only taken when there is no status to release.
#define ASSERT_ORTSTATUS_NOT_OK(api, function)                          \
  do {                                                                  \
    OrtStatusPtr _tmp_status = (api->function);                         \
    if (_tmp_status == nullptr) {                                       \
      /* always fails here; ASSERT_NE keeps gtest's standard message */ \
      ASSERT_NE(_tmp_status, nullptr);                                  \
    } else {                                                            \
      api->ReleaseStatus(_tmp_status);                                  \
    }                                                                   \
  } while (false)
|
|
|
|
// EXPECT_ORTSTATUS_NOT_OK(api, function): non-fatal check that `api->function` fails (returns a non-null
// OrtStatus). A null status records a failure and the test continues; a non-null status is released.
#define EXPECT_ORTSTATUS_NOT_OK(api, function)                          \
  do {                                                                  \
    OrtStatusPtr _tmp_status = (api->function);                         \
    if (_tmp_status == nullptr) {                                       \
      /* always fails here; EXPECT_NE keeps gtest's standard message */ \
      EXPECT_NE(_tmp_status, nullptr);                                  \
    } else {                                                            \
      api->ReleaseStatus(_tmp_status);                                  \
    }                                                                   \
  } while (false)
|