onnxruntime/onnxruntime/test/util/include/asserts.h
Scott McKay 8f2e57f5d0
Make session configuration options available to kernels via OpKernelInfo (#18897)
### Description
<!-- Describe your changes. -->
Pass through the ConfigOptions from the session via OpKernelInfo so that
kernel behavior can be configured.

Initial usage would be to optionally enable a fast path for ARM64 bfloat16 GEMM - see #17031
Other usages could be things like selecting the exact implementations of the activation functions for RNN operators instead of the default approximations (e.g. use [sigmoid_exact instead of sigmoid](2d6e2e243d/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h (L379-L382)))

OpKernelInfo is already passing through things from the session state, and adding a new member of ConfigOptions
is the simpler update. It's also a more natural fit given it's providing state/info to the kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-01-13 10:02:43 +10:00

79 lines
3.5 KiB
C

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>

#include "core/common/status.h"
#include "core/session/onnxruntime_c_api.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
// helpers to run a function and check the status, outputting any error if it fails.
// note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
// ASSERT_STATUS_OK(function)
//   Evaluates `function` (an expression yielding an onnxruntime Status). If the status is not OK the test
//   fails fatally (ASSERT_* returns from the current function) and the status is streamed into the
//   failure message.
//   The Status type is fully qualified so the macro works regardless of which `Status` name (if any) is
//   visible at the expansion site.
#define ASSERT_STATUS_OK(function)                                \
  do {                                                            \
    ::onnxruntime::common::Status _tmp_status = (function);       \
    ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status;               \
  } while (false)
// EXPECT_STATUS_OK(function)
//   Evaluates `function` (an expression yielding an onnxruntime Status). If the status is not OK a
//   non-fatal failure is recorded (execution continues) with the status streamed into the message.
//   The Status type is fully qualified so the macro does not depend on name lookup at the expansion site.
#define EXPECT_STATUS_OK(function)                                \
  do {                                                            \
    ::onnxruntime::common::Status _tmp_status = (function);       \
    EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status;               \
  } while (false)
// ASSERT_STATUS_NOT_OK(function)
//   Evaluates `function` (an expression yielding an onnxruntime Status) and fails the test fatally
//   (ASSERT_* returns from the current function) if the status IS OK, i.e. the call was expected to fail.
//   The Status type is fully qualified so the macro does not depend on name lookup at the expansion site.
#define ASSERT_STATUS_NOT_OK(function)                            \
  do {                                                            \
    ::onnxruntime::common::Status _tmp_status = (function);       \
    ASSERT_FALSE(_tmp_status.IsOK());                             \
  } while (false)
// EXPECT_STATUS_NOT_OK(function)
//   Evaluates `function` (an expression yielding an onnxruntime Status) and records a non-fatal failure
//   if the status IS OK, i.e. the call was expected to fail.
//   The Status type is fully qualified so the macro does not depend on name lookup at the expansion site.
#define EXPECT_STATUS_NOT_OK(function)                            \
  do {                                                            \
    ::onnxruntime::common::Status _tmp_status = (function);       \
    EXPECT_FALSE(_tmp_status.IsOK());                             \
  } while (false)
// ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)
//   Evaluates `function` (an expression yielding an onnxruntime Status) and fails the test fatally unless
//   the status is an error whose ErrorMessage() contains the substring `msg`.
//   The Status type is fully qualified (like ::testing::HasSubstr) so the macro does not depend on name
//   lookup at the expansion site.
#define ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)                \
  do {                                                                    \
    ::onnxruntime::common::Status _tmp_status = (function);               \
    ASSERT_FALSE(_tmp_status.IsOK());                                     \
    ASSERT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg));   \
  } while (false)
// EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)
//   Evaluates `function` (an expression yielding an onnxruntime Status) and records non-fatal failures
//   unless the status is an error whose ErrorMessage() contains the substring `msg`.
//   The Status type is fully qualified (like ::testing::HasSubstr) so the macro does not depend on name
//   lookup at the expansion site.
#define EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)                \
  do {                                                                    \
    ::onnxruntime::common::Status _tmp_status = (function);               \
    EXPECT_FALSE(_tmp_status.IsOK());                                     \
    EXPECT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg));   \
  } while (false)
// Same helpers for public API OrtStatus. Get the 'api' instance using:
// const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
// ASSERT_ORTSTATUS_OK(api, function)
//   Calls `api->function` (e.g. `function` = `CreateEnv(...)`, with `api` a `const OrtApi*`) and fails the
//   test fatally, with the OrtStatus error message, if a non-null OrtStatus is returned.
//   A fatal ASSERT_* returns from the enclosing function when it fails, so any cleanup placed after it is
//   skipped. The error message is therefore copied out and the status released BEFORE asserting;
//   otherwise the OrtStatus would leak exactly when the check fails.
//   `api` is parenthesized so an expression (e.g. `&some_api`) can be passed safely.
#define ASSERT_ORTSTATUS_OK(api, function)                                \
  do {                                                                    \
    OrtStatusPtr _tmp_status = ((api)->function);                         \
    const bool _tmp_status_ok = (_tmp_status == nullptr);                 \
    ::std::string _tmp_status_msg;                                        \
    if (!_tmp_status_ok) {                                                \
      _tmp_status_msg = (api)->GetErrorMessage(_tmp_status);              \
      (api)->ReleaseStatus(_tmp_status);                                  \
    }                                                                     \
    ASSERT_TRUE(_tmp_status_ok) << _tmp_status_msg;                       \
  } while (false)
// EXPECT_ORTSTATUS_OK(api, function)
//   Calls `api->function` (with `api` a `const OrtApi*`) and records a non-fatal failure, with the
//   OrtStatus error message, if a non-null OrtStatus is returned.
//   EXPECT_* does not return on failure, so the status is always released afterwards. The streamed
//   message is only evaluated when the comparison fails, i.e. GetErrorMessage is never called on nullptr.
//   `api` is parenthesized so an expression (e.g. `&some_api`) can be passed safely.
#define EXPECT_ORTSTATUS_OK(api, function)                                    \
  do {                                                                        \
    OrtStatusPtr _tmp_status = ((api)->function);                             \
    EXPECT_EQ(_tmp_status, nullptr) << (api)->GetErrorMessage(_tmp_status);  \
    if (_tmp_status) (api)->ReleaseStatus(_tmp_status);                       \
  } while (false)
// ASSERT_ORTSTATUS_NOT_OK(api, function)
//   Calls `api->function` (with `api` a `const OrtApi*`) and fails the test fatally if it returns a null
//   OrtStatus, i.e. the call was expected to fail.
//   The assertion can only fail when the status is null, so the early return it performs never skips a
//   needed ReleaseStatus.
//   `api` is parenthesized so an expression (e.g. `&some_api`) can be passed safely.
#define ASSERT_ORTSTATUS_NOT_OK(api, function)                \
  do {                                                        \
    OrtStatusPtr _tmp_status = ((api)->function);             \
    ASSERT_NE(_tmp_status, nullptr);                          \
    if (_tmp_status) (api)->ReleaseStatus(_tmp_status);       \
  } while (false)
// EXPECT_ORTSTATUS_NOT_OK(api, function)
//   Calls `api->function` (with `api` a `const OrtApi*`) and records a non-fatal failure if it returns a
//   null OrtStatus, i.e. the call was expected to fail. Any non-null status is released.
//   `api` is parenthesized so an expression (e.g. `&some_api`) can be passed safely.
#define EXPECT_ORTSTATUS_NOT_OK(api, function)                \
  do {                                                        \
    OrtStatusPtr _tmp_status = ((api)->function);             \
    EXPECT_NE(_tmp_status, nullptr);                          \
    if (_tmp_status) (api)->ReleaseStatus(_tmp_status);       \
  } while (false)